From aa089e81082c05433687a98005665087124cba0a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 15:23:06 +1100 Subject: [PATCH 01/67] tidy(nodes): move all field things to fields.py Unfortunately, this is necessary to prevent circular imports at runtime. --- invokeai/app/api/routers/images.py | 2 +- invokeai/app/api_app.py | 3 +- invokeai/app/invocations/baseinvocation.py | 453 +--------------- invokeai/app/invocations/collections.py | 3 +- invokeai/app/invocations/compel.py | 6 +- .../controlnet_image_processors.py | 6 +- invokeai/app/invocations/cv.py | 3 +- invokeai/app/invocations/facetools.py | 4 +- invokeai/app/invocations/fields.py | 501 ++++++++++++++++++ invokeai/app/invocations/image.py | 5 +- invokeai/app/invocations/infill.py | 3 +- invokeai/app/invocations/ip_adapter.py | 5 +- invokeai/app/invocations/latent.py | 7 +- invokeai/app/invocations/math.py | 4 +- invokeai/app/invocations/metadata.py | 6 +- invokeai/app/invocations/model.py | 5 +- invokeai/app/invocations/noise.py | 4 +- invokeai/app/invocations/onnx.py | 16 +- invokeai/app/invocations/param_easing.py | 3 +- invokeai/app/invocations/primitives.py | 6 +- invokeai/app/invocations/prompt.py | 3 +- invokeai/app/invocations/sdxl.py | 6 +- invokeai/app/invocations/strings.py | 4 +- invokeai/app/invocations/t2i_adapter.py | 5 +- invokeai/app/invocations/tiles.py | 5 +- invokeai/app/invocations/upscale.py | 3 +- .../services/image_files/image_files_base.py | 2 +- .../services/image_files/image_files_disk.py | 2 +- .../image_records/image_records_base.py | 2 +- .../image_records/image_records_sqlite.py | 2 +- invokeai/app/services/images/images_base.py | 2 +- .../app/services/images/images_default.py | 2 +- invokeai/app/services/shared/graph.py | 5 +- invokeai/app/shared/fields.py | 67 --- invokeai/app/shared/models.py | 2 +- tests/aa_nodes/test_nodes.py | 3 +- 36 files changed, 552 insertions(+), 608 deletions(-) create mode 100644 invokeai/app/invocations/fields.py delete mode 100644 invokeai/app/shared/fields.py diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 125896b8d3a..cc60ad1be83 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -8,7 +8,7 @@ from PIL import Image from pydantic import BaseModel, Field, ValidationError -from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator +from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 6294083d0e1..f48074de7c7 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -6,6 +6,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.version.invokeai_version import __version__ +from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra from .services.config import InvokeAIAppConfig app_config = InvokeAIAppConfig.get_config() @@ -57,8 +58,6 @@ from .api.sockets import SocketIO from .invocations.baseinvocation import ( BaseInvocation, - InputFieldJSONSchemaExtra, - OutputFieldJSONSchemaExtra, UIConfigBase, ) diff --git a/invokeai/app/invocations/baseinvocation.py 
b/invokeai/app/invocations/baseinvocation.py index d9e0c7ba0d2..395d5e98707 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -12,10 +12,11 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast import semver -from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model -from pydantic.fields import FieldInfo, _Unset +from pydantic import BaseModel, ConfigDict, Field, create_model +from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined +from invokeai.app.invocations.fields import FieldKind, Input from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.shared.fields import FieldDescriptions @@ -52,393 +53,6 @@ class Classification(str, Enum, metaclass=MetaEnum): Prototype = "prototype" -class Input(str, Enum, metaclass=MetaEnum): - """ - The type of input a field accepts. - - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ - are instantiated. - - `Input.Connection`: The field must have its value provided by a connection. - - `Input.Any`: The field may have its value provided either directly or by a connection. - """ - - Connection = "connection" - Direct = "direct" - Any = "any" - - -class FieldKind(str, Enum, metaclass=MetaEnum): - """ - The kind of field. - - `Input`: An input field on a node. - - `Output`: An output field on a node. - - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is - one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name - "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, - allowing "metadata" for that field. - - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, - but which are used to store information about the node. For example, the `id` and `type` fields are node - attributes. - - The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app - startup, and when generating the OpenAPI schema for the workflow editor. - """ - - Input = "input" - Output = "output" - Internal = "internal" - NodeAttribute = "node_attribute" - - -class UIType(str, Enum, metaclass=MetaEnum): - """ - Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. - - - Model Fields - The most common node-author-facing use will be for model fields. Internally, there is no difference - between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the - base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that - the field is an SDXL main model field. - - - Any Field - We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to - indicate that the field accepts any type. Use with caution. This cannot be used on outputs. - - - Scheduler Field - Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. - - - Internal Fields - Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. 
To facilitate - handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These - should not be used by node authors. - - - DEPRECATED Fields - These types are deprecated and should not be used by node authors. A warning will be logged if one is - used, and the type will be ignored. They are included here for backwards compatibility. - """ - - # region Model Field Types - SDXLMainModel = "SDXLMainModelField" - SDXLRefinerModel = "SDXLRefinerModelField" - ONNXModel = "ONNXModelField" - VaeModel = "VAEModelField" - LoRAModel = "LoRAModelField" - ControlNetModel = "ControlNetModelField" - IPAdapterModel = "IPAdapterModelField" - # endregion - - # region Misc Field Types - Scheduler = "SchedulerField" - Any = "AnyField" - # endregion - - # region Internal Field Types - _Collection = "CollectionField" - _CollectionItem = "CollectionItemField" - # endregion - - # region DEPRECATED - Boolean = "DEPRECATED_Boolean" - Color = "DEPRECATED_Color" - Conditioning = "DEPRECATED_Conditioning" - Control = "DEPRECATED_Control" - Float = "DEPRECATED_Float" - Image = "DEPRECATED_Image" - Integer = "DEPRECATED_Integer" - Latents = "DEPRECATED_Latents" - String = "DEPRECATED_String" - BooleanCollection = "DEPRECATED_BooleanCollection" - ColorCollection = "DEPRECATED_ColorCollection" - ConditioningCollection = "DEPRECATED_ConditioningCollection" - ControlCollection = "DEPRECATED_ControlCollection" - FloatCollection = "DEPRECATED_FloatCollection" - ImageCollection = "DEPRECATED_ImageCollection" - IntegerCollection = "DEPRECATED_IntegerCollection" - LatentsCollection = "DEPRECATED_LatentsCollection" - StringCollection = "DEPRECATED_StringCollection" - BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" - ColorPolymorphic = "DEPRECATED_ColorPolymorphic" - ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" - ControlPolymorphic = "DEPRECATED_ControlPolymorphic" - FloatPolymorphic = "DEPRECATED_FloatPolymorphic" - ImagePolymorphic = "DEPRECATED_ImagePolymorphic" - IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" - LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" - StringPolymorphic = "DEPRECATED_StringPolymorphic" - MainModel = "DEPRECATED_MainModel" - UNet = "DEPRECATED_UNet" - Vae = "DEPRECATED_Vae" - CLIP = "DEPRECATED_CLIP" - Collection = "DEPRECATED_Collection" - CollectionItem = "DEPRECATED_CollectionItem" - Enum = "DEPRECATED_Enum" - WorkflowField = "DEPRECATED_WorkflowField" - IsIntermediate = "DEPRECATED_IsIntermediate" - BoardField = "DEPRECATED_BoardField" - MetadataItem = "DEPRECATED_MetadataItem" - MetadataItemCollection = "DEPRECATED_MetadataItemCollection" - MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" - MetadataDict = "DEPRECATED_MetadataDict" - # endregion - - -class UIComponent(str, Enum, metaclass=MetaEnum): - """ - The type of UI component to use for a field, used to override the default components, which are - inferred from the field type. - """ - - None_ = "none" - Textarea = "textarea" - Slider = "slider" - - -class InputFieldJSONSchemaExtra(BaseModel): - """ - Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, - and by the workflow editor during schema parsing and UI rendering. 
- """ - - input: Input - orig_required: bool - field_kind: FieldKind - default: Optional[Any] = None - orig_default: Optional[Any] = None - ui_hidden: bool = False - ui_type: Optional[UIType] = None - ui_component: Optional[UIComponent] = None - ui_order: Optional[int] = None - ui_choice_labels: Optional[dict[str, str]] = None - - model_config = ConfigDict( - validate_assignment=True, - json_schema_serialization_defaults_required=True, - ) - - -class OutputFieldJSONSchemaExtra(BaseModel): - """ - Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor - during schema parsing and UI rendering. - """ - - field_kind: FieldKind - ui_hidden: bool - ui_type: Optional[UIType] - ui_order: Optional[int] - - model_config = ConfigDict( - validate_assignment=True, - json_schema_serialization_defaults_required=True, - ) - - -def InputField( - # copied from pydantic's Field - # TODO: Can we support default_factory? - default: Any = _Unset, - default_factory: Callable[[], Any] | None = _Unset, - title: str | None = _Unset, - description: str | None = _Unset, - pattern: str | None = _Unset, - strict: bool | None = _Unset, - gt: float | None = _Unset, - ge: float | None = _Unset, - lt: float | None = _Unset, - le: float | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - # custom - input: Input = Input.Any, - ui_type: Optional[UIType] = None, - ui_component: Optional[UIComponent] = None, - ui_hidden: bool = False, - ui_order: Optional[int] = None, - ui_choice_labels: Optional[dict[str, str]] = None, -) -> Any: - """ - Creates an input field for an invocation. - - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ - that adds a few extra parameters to support graph execution and the node editor UI. - - :param Input input: [Input.Any] The kind of input this field requires. \ - `Input.Direct` means a value must be provided on instantiation. \ - `Input.Connection` means the value must be provided by a connection. \ - `Input.Any` means either will do. - - :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ - In some situations, the field's type is not enough to infer the correct UI type. \ - For example, model selection fields should render a dropdown UI component to select a model. \ - Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ - `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ - `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - - :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \ - The UI will always render a suitable component, but sometimes you want something different than the default. \ - For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ - For this case, you could provide `UIComponent.Textarea`. - - :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. - - :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. - - :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. 
- """ - - json_schema_extra_ = InputFieldJSONSchemaExtra( - input=input, - ui_type=ui_type, - ui_component=ui_component, - ui_hidden=ui_hidden, - ui_order=ui_order, - ui_choice_labels=ui_choice_labels, - field_kind=FieldKind.Input, - orig_required=True, - ) - - """ - There is a conflict between the typing of invocation definitions and the typing of an invocation's - `invoke()` function. - - On instantiation of a node, the invocation definition is used to create the python class. At this time, - any number of fields may be optional, because they may be provided by connections. - - On calling of `invoke()`, however, those fields may be required. - - For example, consider an ResizeImageInvocation with an `image: ImageField` field. - - `image` is required during the call to `invoke()`, but when the python class is instantiated, - the field may not be present. This is fine, because that image field will be provided by a - connection from an ancestor node, which outputs an image. - - This means we want to type the `image` field as optional for the node class definition, but required - for the `invoke()` function. - - If we use `typing.Optional` in the node class definition, the field will be typed as optional in the - `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or - any static type analysis tools will complain. - - To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, - but secretly make them optional in `InputField()`. We also store the original required bool and/or default - value. When we call `invoke()`, we use this stored information to do an additional check on the class. - """ - - if default_factory is not _Unset and default_factory is not None: - default = default_factory() - logger.warn('"default_factory" is not supported, calling it now to set "default"') - - # These are the args we may wish pass to the pydantic `Field()` function - field_args = { - "default": default, - "title": title, - "description": description, - "pattern": pattern, - "strict": strict, - "gt": gt, - "ge": ge, - "lt": lt, - "le": le, - "multiple_of": multiple_of, - "allow_inf_nan": allow_inf_nan, - "max_digits": max_digits, - "decimal_places": decimal_places, - "min_length": min_length, - "max_length": max_length, - } - - # We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected - provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} - - # Because we are manually making fields optional, we need to store the original required bool for reference later - json_schema_extra_.orig_required = default is PydanticUndefined - - # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one - if input is Input.Any or input is Input.Connection: - default_ = None if default is PydanticUndefined else default - provided_args.update({"default": default_}) - if default is not PydanticUndefined: - # Before invoking, we'll check for the original default value and set it on the field if the field has no value - json_schema_extra_.default = default - json_schema_extra_.orig_default = default - elif default is not PydanticUndefined: - default_ = default - provided_args.update({"default": default_}) - json_schema_extra_.orig_default = default_ - - return Field( - **provided_args, - json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), - ) - - -def OutputField( - # copied from 
pydantic's Field - default: Any = _Unset, - title: str | None = _Unset, - description: str | None = _Unset, - pattern: str | None = _Unset, - strict: bool | None = _Unset, - gt: float | None = _Unset, - ge: float | None = _Unset, - lt: float | None = _Unset, - le: float | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - # custom - ui_type: Optional[UIType] = None, - ui_hidden: bool = False, - ui_order: Optional[int] = None, -) -> Any: - """ - Creates an output field for an invocation output. - - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ - that adds a few extra parameters to support graph execution and the node editor UI. - - :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ - In some situations, the field's type is not enough to infer the correct UI type. \ - For example, model selection fields should render a dropdown UI component to select a model. \ - Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ - `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ - `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - - :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ - - :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ - """ - return Field( - default=default, - title=title, - description=description, - pattern=pattern, - strict=strict, - gt=gt, - ge=ge, - lt=lt, - le=le, - multiple_of=multiple_of, - allow_inf_nan=allow_inf_nan, - max_digits=max_digits, - decimal_places=decimal_places, - min_length=min_length, - max_length=max_length, - json_schema_extra=OutputFieldJSONSchemaExtra( - ui_type=ui_type, - ui_hidden=ui_hidden, - ui_order=ui_order, - field_kind=FieldKind.Output, - ).model_dump(exclude_none=True), - ) - - class UIConfigBase(BaseModel): """ Provides additional node configuration to the UI. @@ -460,33 +74,6 @@ class UIConfigBase(BaseModel): ) -class InvocationContext: - """Initialized and provided to on execution of invocations.""" - - services: InvocationServices - graph_execution_state_id: str - queue_id: str - queue_item_id: int - queue_batch_id: str - workflow: Optional[WorkflowWithoutID] - - def __init__( - self, - services: InvocationServices, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - workflow: Optional[WorkflowWithoutID], - ): - self.services = services - self.graph_execution_state_id = graph_execution_state_id - self.queue_id = queue_id - self.queue_item_id = queue_item_id - self.queue_batch_id = queue_batch_id - self.workflow = workflow - - class BaseInvocationOutput(BaseModel): """ Base class for all invocation outputs. @@ -926,37 +513,3 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: return cls return wrapper - - -class MetadataField(RootModel): - """ - Pydantic model for metadata with custom root of type dict[str, Any]. - Metadata is stored without a strict schema. 
- """ - - root: dict[str, Any] = Field(description="The metadata") - - -MetadataFieldValidator = TypeAdapter(MetadataField) - - -class WithMetadata(BaseModel): - metadata: Optional[MetadataField] = Field( - default=None, - description=FieldDescriptions.metadata, - json_schema_extra=InputFieldJSONSchemaExtra( - field_kind=FieldKind.Internal, - input=Input.Connection, - orig_required=False, - ).model_dump(exclude_none=True), - ) - - -class WithWorkflow: - workflow = None - - def __init_subclass__(cls) -> None: - logger.warn( - f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." - ) - super().__init_subclass__() diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 4c7b6f94cd4..d35a9d79c74 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -7,7 +7,8 @@ from invokeai.app.invocations.primitives import IntegerCollectionOutput from invokeai.app.util.misc import SEED_MAX -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField @invocation( diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 49c62cff564..b386aef2cbe 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -5,8 +5,8 @@ from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ExtraConditioningInfo, @@ -20,11 +20,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIComponent, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 1f9342985a0..9b652b8eee9 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,10 +25,10 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector @@ -36,11 +36,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index cb6828d21ac..5865338e192 100644 --- a/invokeai/app/invocations/cv.py +++ 
b/invokeai/app/invocations/cv.py @@ -8,7 +8,8 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, WithMetadata @invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0") diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index e0c89b4de5a..13f1066ec3e 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -13,13 +13,11 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - InputField, InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) +from invokeai.app.invocations.fields import InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py new file mode 100644 index 00000000000..0cce8e3c6b5 --- /dev/null +++ b/invokeai/app/invocations/fields.py @@ -0,0 +1,501 @@ +from enum import Enum +from typing import Any, Callable, Optional + +from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter +from pydantic.fields import _Unset +from pydantic_core import PydanticUndefined + +from invokeai.app.util.metaenum import MetaEnum +from invokeai.backend.util.logging import InvokeAILogger + +logger = InvokeAILogger.get_logger() + + +class UIType(str, Enum, metaclass=MetaEnum): + """ + Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. + + - Model Fields + The most common node-author-facing use will be for model fields. Internally, there is no difference + between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the + base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that + the field is an SDXL main model field. + + - Any Field + We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to + indicate that the field accepts any type. Use with caution. This cannot be used on outputs. + + - Scheduler Field + Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. + + - Internal Fields + Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate + handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These + should not be used by node authors. + + - DEPRECATED Fields + These types are deprecated and should not be used by node authors. A warning will be logged if one is + used, and the type will be ignored. They are included here for backwards compatibility. 
+ """ + + # region Model Field Types + SDXLMainModel = "SDXLMainModelField" + SDXLRefinerModel = "SDXLRefinerModelField" + ONNXModel = "ONNXModelField" + VaeModel = "VAEModelField" + LoRAModel = "LoRAModelField" + ControlNetModel = "ControlNetModelField" + IPAdapterModel = "IPAdapterModelField" + # endregion + + # region Misc Field Types + Scheduler = "SchedulerField" + Any = "AnyField" + # endregion + + # region Internal Field Types + _Collection = "CollectionField" + _CollectionItem = "CollectionItemField" + # endregion + + # region DEPRECATED + Boolean = "DEPRECATED_Boolean" + Color = "DEPRECATED_Color" + Conditioning = "DEPRECATED_Conditioning" + Control = "DEPRECATED_Control" + Float = "DEPRECATED_Float" + Image = "DEPRECATED_Image" + Integer = "DEPRECATED_Integer" + Latents = "DEPRECATED_Latents" + String = "DEPRECATED_String" + BooleanCollection = "DEPRECATED_BooleanCollection" + ColorCollection = "DEPRECATED_ColorCollection" + ConditioningCollection = "DEPRECATED_ConditioningCollection" + ControlCollection = "DEPRECATED_ControlCollection" + FloatCollection = "DEPRECATED_FloatCollection" + ImageCollection = "DEPRECATED_ImageCollection" + IntegerCollection = "DEPRECATED_IntegerCollection" + LatentsCollection = "DEPRECATED_LatentsCollection" + StringCollection = "DEPRECATED_StringCollection" + BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" + ColorPolymorphic = "DEPRECATED_ColorPolymorphic" + ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" + ControlPolymorphic = "DEPRECATED_ControlPolymorphic" + FloatPolymorphic = "DEPRECATED_FloatPolymorphic" + ImagePolymorphic = "DEPRECATED_ImagePolymorphic" + IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" + LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" + StringPolymorphic = "DEPRECATED_StringPolymorphic" + MainModel = "DEPRECATED_MainModel" + UNet = "DEPRECATED_UNet" + Vae = "DEPRECATED_Vae" + CLIP = "DEPRECATED_CLIP" + Collection = "DEPRECATED_Collection" + CollectionItem = "DEPRECATED_CollectionItem" + Enum = "DEPRECATED_Enum" + WorkflowField = "DEPRECATED_WorkflowField" + IsIntermediate = "DEPRECATED_IsIntermediate" + BoardField = "DEPRECATED_BoardField" + MetadataItem = "DEPRECATED_MetadataItem" + MetadataItemCollection = "DEPRECATED_MetadataItemCollection" + MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" + MetadataDict = "DEPRECATED_MetadataDict" + + +class UIComponent(str, Enum, metaclass=MetaEnum): + """ + The type of UI component to use for a field, used to override the default components, which are + inferred from the field type. 
+ """ + + None_ = "none" + Textarea = "textarea" + Slider = "slider" + + +class FieldDescriptions: + denoising_start = "When to start denoising, expressed a percentage of total steps" + denoising_end = "When to stop denoising, expressed a percentage of total steps" + cfg_scale = "Classifier-Free Guidance scale" + cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR" + scheduler = "Scheduler to use during inference" + positive_cond = "Positive conditioning tensor" + negative_cond = "Negative conditioning tensor" + noise = "Noise tensor" + clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" + unet = "UNet (scheduler, LoRAs)" + vae = "VAE" + cond = "Conditioning tensor" + controlnet_model = "ControlNet model to load" + vae_model = "VAE model to load" + lora_model = "LoRA model to load" + main_model = "Main model (UNet, VAE, CLIP) to load" + sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" + sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" + onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" + lora_weight = "The weight at which the LoRA is applied to each model" + compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" + raw_prompt = "Raw prompt text (no parsing)" + sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor" + skipped_layers = "Number of layers to skip in text encoder" + seed = "Seed for random number generation" + steps = "Number of steps to run" + width = "Width of output (px)" + height = "Height of output (px)" + control = "ControlNet(s) to apply" + ip_adapter = "IP-Adapter to apply" + t2i_adapter = "T2I-Adapter(s) to apply" + denoised_latents = "Denoised latents tensor" + latents = "Latents tensor" + strength = "Strength of denoising (proportional to steps)" + metadata = "Optional metadata to be saved with the image" + metadata_collection = "Collection of Metadata" + metadata_item_polymorphic = "A single metadata item or collection of metadata items" + metadata_item_label = "Label for this metadata item" + metadata_item_value = "The value for this metadata item (may be any type)" + workflow = "Optional workflow to be saved with the image" + interp_mode = "Interpolation mode" + torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" + fp32 = "Whether or not to use full float32 precision" + precision = "Precision to use" + tiled = "Processing using overlapping tiles (reduce memory consumption)" + detect_res = "Pixel resolution for detection" + image_res = "Pixel resolution for output image" + safe_mode = "Whether or not to use safe mode" + scribble_mode = "Whether or not to use scribble mode" + scale_factor = "The factor by which to scale" + blend_alpha = ( + "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." + ) + num_1 = "The first number" + num_2 = "The second number" + mask = "The mask to use for the operation" + board = "The board to save the image to" + image = "The image to process" + tile_size = "Tile size" + inclusive_low = "The inclusive low value" + exclusive_high = "The exclusive high value" + decimal_places = "The number of decimal places to round to" + freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' 
+ freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' + freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features." + freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." + + +class MetadataField(RootModel): + """ + Pydantic model for metadata with custom root of type dict[str, Any]. + Metadata is stored without a strict schema. + """ + + root: dict[str, Any] = Field(description="The metadata") + + +MetadataFieldValidator = TypeAdapter(MetadataField) + + +class Input(str, Enum, metaclass=MetaEnum): + """ + The type of input a field accepts. + - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ + are instantiated. + - `Input.Connection`: The field must have its value provided by a connection. + - `Input.Any`: The field may have its value provided either directly or by a connection. + """ + + Connection = "connection" + Direct = "direct" + Any = "any" + + +class FieldKind(str, Enum, metaclass=MetaEnum): + """ + The kind of field. + - `Input`: An input field on a node. + - `Output`: An output field on a node. + - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is + one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name + "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, + allowing "metadata" for that field. + - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, + but which are used to store information about the node. For example, the `id` and `type` fields are node + attributes. + + The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app + startup, and when generating the OpenAPI schema for the workflow editor. + """ + + Input = "input" + Output = "output" + Internal = "internal" + NodeAttribute = "node_attribute" + + +class InputFieldJSONSchemaExtra(BaseModel): + """ + Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, + and by the workflow editor during schema parsing and UI rendering. + """ + + input: Input + orig_required: bool + field_kind: FieldKind + default: Optional[Any] = None + orig_default: Optional[Any] = None + ui_hidden: bool = False + ui_type: Optional[UIType] = None + ui_component: Optional[UIComponent] = None + ui_order: Optional[int] = None + ui_choice_labels: Optional[dict[str, str]] = None + + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + ) + + +class WithMetadata(BaseModel): + metadata: Optional[MetadataField] = Field( + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + +class WithWorkflow: + workflow = None + + def __init_subclass__(cls) -> None: + logger.warn( + f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." + ) + super().__init_subclass__() + + +class OutputFieldJSONSchemaExtra(BaseModel): + """ + Extra attributes to be added to output fields and their OpenAPI schema.
Used by the workflow editor + during schema parsing and UI rendering. + """ + + field_kind: FieldKind + ui_hidden: bool + ui_type: Optional[UIType] + ui_order: Optional[int] + + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + ) + + +def InputField( + # copied from pydantic's Field + # TODO: Can we support default_factory? + default: Any = _Unset, + default_factory: Callable[[], Any] | None = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + # custom + input: Input = Input.Any, + ui_type: Optional[UIType] = None, + ui_component: Optional[UIComponent] = None, + ui_hidden: bool = False, + ui_order: Optional[int] = None, + ui_choice_labels: Optional[dict[str, str]] = None, +) -> Any: + """ + Creates an input field for an invocation. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param Input input: [Input.Any] The kind of input this field requires. \ + `Input.Direct` means a value must be provided on instantiation. \ + `Input.Connection` means the value must be provided by a connection. \ + `Input.Any` means either will do. + + :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \ + The UI will always render a suitable component, but sometimes you want something different than the default. \ + For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ + For this case, you could provide `UIComponent.Textarea`. + + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. + + :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. + """ + + json_schema_extra_ = InputFieldJSONSchemaExtra( + input=input, + ui_type=ui_type, + ui_component=ui_component, + ui_hidden=ui_hidden, + ui_order=ui_order, + ui_choice_labels=ui_choice_labels, + field_kind=FieldKind.Input, + orig_required=True, + ) + + """ + There is a conflict between the typing of invocation definitions and the typing of an invocation's + `invoke()` function. + + On instantiation of a node, the invocation definition is used to create the python class. At this time, + any number of fields may be optional, because they may be provided by connections. 
+ + On calling of `invoke()`, however, those fields may be required. + + For example, consider a ResizeImageInvocation with an `image: ImageField` field. + + `image` is required during the call to `invoke()`, but when the python class is instantiated, + the field may not be present. This is fine, because that image field will be provided by a + connection from an ancestor node, which outputs an image. + + This means we want to type the `image` field as optional for the node class definition, but required + for the `invoke()` function. + + If we use `typing.Optional` in the node class definition, the field will be typed as optional in the + `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or + any static type analysis tools will complain. + + To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, + but secretly make them optional in `InputField()`. We also store the original required bool and/or default + value. When we call `invoke()`, we use this stored information to do an additional check on the class. + """ + + if default_factory is not _Unset and default_factory is not None: + default = default_factory() + logger.warn('"default_factory" is not supported, calling it now to set "default"') + + # These are the args we may wish to pass to the pydantic `Field()` function + field_args = { + "default": default, + "title": title, + "description": description, + "pattern": pattern, + "strict": strict, + "gt": gt, + "ge": ge, + "lt": lt, + "le": le, + "multiple_of": multiple_of, + "allow_inf_nan": allow_inf_nan, + "max_digits": max_digits, + "decimal_places": decimal_places, + "min_length": min_length, + "max_length": max_length, + } + + # We only want to pass the args that were provided, otherwise the `Field()` function won't work as expected + provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} + + # Because we are manually making fields optional, we need to store the original required bool for reference later + json_schema_extra_.orig_required = default is PydanticUndefined + + # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one + if input is Input.Any or input is Input.Connection: + default_ = None if default is PydanticUndefined else default + provided_args.update({"default": default_}) + if default is not PydanticUndefined: + # Before invoking, we'll check for the original default value and set it on the field if the field has no value + json_schema_extra_.default = default + json_schema_extra_.orig_default = default + elif default is not PydanticUndefined: + default_ = default + provided_args.update({"default": default_}) + json_schema_extra_.orig_default = default_ + + return Field( + **provided_args, + json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), + ) + + +def OutputField( + # copied from pydantic's Field + default: Any = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + # custom + ui_type: Optional[UIType] = None, + ui_hidden: bool = False, +
ui_order: Optional[int] = None, +) -> Any: + """ + Creates an output field for an invocation output. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ + + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ + """ + return Field( + default=default, + title=title, + description=description, + pattern=pattern, + strict=strict, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_length=min_length, + max_length=max_length, + json_schema_extra=OutputFieldJSONSchemaExtra( + ui_type=ui_type, + ui_hidden=ui_hidden, + ui_order=ui_order, + field_kind=FieldKind.Output, + ).model_dump(exclude_none=True), + ) + # endregion diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index f729d60cdd5..16d0f33dda3 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,19 +7,16 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, WithMetadata from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker from .baseinvocation import ( BaseInvocation, Classification, - Input, - InputField, InvocationContext, - WithMetadata, invocation, ) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index c3d00bb1330..d4d3d5bea44 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -13,7 +13,8 @@ from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 6bd28896244..c01e0ed0fb2 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -7,16 +7,13 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) +from 
invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.model_management.models.base import BaseModelType, ModelType from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b77363ceb86..909c307481e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,6 +23,7 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, WithMetadata from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.primitives import ( DenoiseMaskField, @@ -35,7 +36,6 @@ ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus @@ -59,12 +59,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIType, - WithMetadata, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index defc61275fe..6ca53011f0b 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -5,10 +5,10 @@ import numpy as np from pydantic import ValidationInfo, field_validator +from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput -from invokeai.app.shared.fields import FieldDescriptions -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0") diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 14d66f8ef68..399e217dc17 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -5,20 +5,16 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - MetadataField, - OutputField, - UIType, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import ControlField +from invokeai.app.invocations.fields import FieldDescriptions, InputField, MetadataField, OutputField, UIType from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.shared.fields import FieldDescriptions from ...version import __version__ diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 
99dcc72999b..c710c9761b0 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -3,17 +3,14 @@ from pydantic import BaseModel, ConfigDict, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.shared.models import FreeUConfig from ...backend.model_management import BaseModelType, ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index b1ee91e1cdf..2e717ac561b 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,17 +4,15 @@ import torch from pydantic import field_validator +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField from invokeai.app.invocations.latent import LatentsField -from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 759cfde700f..b43d7eaef2c 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -11,9 +11,17 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from tqdm import tqdm +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + OutputField, + UIComponent, + UIType, + WithMetadata, +) from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend import BaseModelType, ModelType, SubModelType @@ -24,13 +32,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIComponent, - UIType, - WithMetadata, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index dccd18f754b..dab9c3dc0f4 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -41,7 +41,8 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField @invocation( diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index afe8ff06d9d..22f03454a55 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -5,16 +5,12 @@ import torch from pydantic import BaseModel, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - 
UIComponent, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 4778d980771..94b4a217ae7 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -7,7 +7,8 @@ from invokeai.app.invocations.primitives import StringCollectionOutput -from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, UIComponent @invocation( diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 68076fdfeb1..62df5bc8047 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,14 +1,10 @@ -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIType, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index 3466206b377..ccbc2f6d924 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -5,13 +5,11 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - OutputField, - UIComponent, invocation, invocation_output, ) +from .fields import InputField, OutputField, UIComponent from .primitives import StringOutput diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index e055d23903f..66ac87c37b8 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -5,17 +5,14 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.model_management.models.base import BaseModelType diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index e51f891a8db..bdc23ef6edd 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,14 +8,11 @@ BaseInvocation, BaseInvocationOutput, Classification, - Input, - InputField, InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) +from invokeai.app.invocations.fields import Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.backend.tiles.tiles import ( diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 5f715c1a7ed..2cab279a9fc 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -14,7 +14,8 @@ from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from 
invokeai.backend.util.devices import choose_torch_device -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, WithMetadata # TODO: Populate this from disk? # TODO: Use model manager to load? diff --git a/invokeai/app/services/image_files/image_files_base.py b/invokeai/app/services/image_files/image_files_base.py index 27dd67531f4..f4036277b72 100644 --- a/invokeai/app/services/image_files/image_files_base.py +++ b/invokeai/app/services/image_files/image_files_base.py @@ -4,7 +4,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 08448216723..fb687973bad 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -7,7 +7,7 @@ from PIL.Image import Image as PILImageType from send2trash import send2trash -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.invoker import Invoker from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 727f4977fba..7b7b261ecab 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Optional -from invokeai.app.invocations.metadata import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.shared.pagination import OffsetPaginatedResults from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 74f82e7d84c..5b37913c8fd 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Optional, Union, cast -from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator +from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index df71dadb5b0..42c42667744 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -3,7 +3,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, ImageRecord, diff --git 
a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index ff21731a506..adeed738119 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -2,7 +2,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 1acf165abac..ba05b050c5b 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -13,14 +13,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIType, invocation, invocation_output, ) +from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType from invokeai.app.util.misc import uuid_string # in 3.10 this would be "from types import NoneType" diff --git a/invokeai/app/shared/fields.py b/invokeai/app/shared/fields.py deleted file mode 100644 index 3e841ffbf22..00000000000 --- a/invokeai/app/shared/fields.py +++ /dev/null @@ -1,67 +0,0 @@ -class FieldDescriptions: - denoising_start = "When to start denoising, expressed a percentage of total steps" - denoising_end = "When to stop denoising, expressed a percentage of total steps" - cfg_scale = "Classifier-Free Guidance scale" - cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR" - scheduler = "Scheduler to use during inference" - positive_cond = "Positive conditioning tensor" - negative_cond = "Negative conditioning tensor" - noise = "Noise tensor" - clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" - unet = "UNet (scheduler, LoRAs)" - vae = "VAE" - cond = "Conditioning tensor" - controlnet_model = "ControlNet model to load" - vae_model = "VAE model to load" - lora_model = "LoRA model to load" - main_model = "Main model (UNet, VAE, CLIP) to load" - sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" - sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" - onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" - lora_weight = "The weight at which the LoRA is applied to each model" - compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" - raw_prompt = "Raw prompt text (no parsing)" - sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor" - skipped_layers = "Number of layers to skip in text encoder" - seed = "Seed for random number generation" - steps = "Number of steps to run" - width = "Width of output (px)" - height = "Height of output (px)" - control = "ControlNet(s) to apply" - ip_adapter = "IP-Adapter to apply" - t2i_adapter = "T2I-Adapter(s) to apply" - denoised_latents = "Denoised latents tensor" - latents = "Latents tensor" - strength = "Strength of denoising (proportional to steps)" - metadata = "Optional metadata to be saved with the image" - metadata_collection = "Collection of Metadata" - metadata_item_polymorphic = "A single metadata item or collection of metadata items" - metadata_item_label = "Label for this metadata item" - metadata_item_value = "The value for this 
metadata item (may be any type)" - workflow = "Optional workflow to be saved with the image" - interp_mode = "Interpolation mode" - torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" - fp32 = "Whether or not to use full float32 precision" - precision = "Precision to use" - tiled = "Processing using overlapping tiles (reduce memory consumption)" - detect_res = "Pixel resolution for detection" - image_res = "Pixel resolution for output image" - safe_mode = "Whether or not to use safe mode" - scribble_mode = "Whether or not to use scribble mode" - scale_factor = "The factor by which to scale" - blend_alpha = ( - "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." - ) - num_1 = "The first number" - num_2 = "The second number" - mask = "The mask to use for the operation" - board = "The board to save the image to" - image = "The image to process" - tile_size = "Tile size" - inclusive_low = "The inclusive low value" - exclusive_high = "The exclusive high value" - decimal_places = "The number of decimal places to round to" - freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' - freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' - freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features." - freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." diff --git a/invokeai/app/shared/models.py b/invokeai/app/shared/models.py index ed68cb287e3..1a11b480cc5 100644 --- a/invokeai/app/shared/models.py +++ b/invokeai/app/shared/models.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions class FreeUConfig(BaseModel): diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index bca4e1011f8..e71daad3f3a 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -3,12 +3,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) +from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.image import ImageField From f0e60a4ba22822a429bf1108a05d17f9e73a0ddb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 18:02:58 +1100 Subject: [PATCH 02/67] feat(nodes): restricts invocation context power Creates a low-power `InvocationContext` with simplified methods and data. See `invocation_context.py` for detailed comments. 
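For a concrete sense of what nodes gain from this, here is a minimal sketch (not part of the patch itself) of a hypothetical node written against the wrapped context. The `FlipImageInvocation` class and its single field are invented for illustration; the `context.images`, `context.logger` and save/get calls follow the interfaces introduced in this series, with import paths as they stand at the end of it:

    from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
    from invokeai.app.invocations.fields import ImageField, InputField
    from invokeai.app.invocations.primitives import ImageOutput
    from invokeai.app.services.shared.invocation_context import InvocationContext


    @invocation("flip_image", title="Flip Image", tags=["image"], category="image", version="1.0.0")
    class FlipImageInvocation(BaseInvocation):
        """Hypothetical node: flips an image upside down."""

        image: ImageField = InputField(description="The image to flip")

        def invoke(self, context: InvocationContext) -> ImageOutput:
            # The wrapped getter replaces context.services.images.get_pil_image(...).
            image = context.images.get_pil(self.image.image_name)
            # The wrapped logger replaces context.services.logger.
            context.logger.info("FlipImage --> flipping image")
            # save() fills in origin, session ID, node ID, workflow and metadata
            # automatically, replacing the long context.services.images.create(...) call.
            image_dto = context.images.save(image=image.rotate(180))
            return ImageOutput(
                image=ImageField(image_name=image_dto.image_name),
                width=image_dto.width,
                height=image_dto.height,
            )

Because each node gets a fresh context exposing only these wrappers, a node cannot, for example, delete another node's images or mutate the app config at runtime.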
--- .../app/services/shared/invocation_context.py | 408 ++++++++++++++++++ invokeai/app/util/step_callback.py | 39 +- 2 files changed, 434 insertions(+), 13 deletions(-) create mode 100644 invokeai/app/services/shared/invocation_context.py diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py new file mode 100644 index 00000000000..c0aaac54f87 --- /dev/null +++ b/invokeai/app/services/shared/invocation_context.py @@ -0,0 +1,408 @@ +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Optional + +from PIL.Image import Image +from pydantic import ConfigDict +from torch import Tensor + +from invokeai.app.invocations.compel import ConditioningFieldData +from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.app.util.misc import uuid_string +from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState + +if TYPE_CHECKING: + from invokeai.app.invocations.baseinvocation import BaseInvocation + +""" +The InvocationContext provides access to various services and data about the current invocation. + +We do not provide the invocation services directly, as their methods are both dangerous and +inconvenient to use. + +For example: +- The `images` service allows nodes to delete or unsafely modify existing images. +- The `configuration` service allows nodes to change the app's config at runtime. +- The `events` service allows nodes to emit arbitrary events. + +Wrapping these services provides a simpler and safer interface for nodes to use. + +When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere +with each other. + +Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them. +""" + + +@dataclass(frozen=True) +class InvocationContextData: + invocation: "BaseInvocation" + session_id: str + queue_id: str + source_node_id: str + queue_item_id: int + batch_id: str + workflow: Optional[WorkflowWithoutID] = None + + +class LoggerInterface: + def __init__(self, services: InvocationServices) -> None: + def debug(message: str) -> None: + """ + Logs a debug message. + + :param message: The message to log. + """ + services.logger.debug(message) + + def info(message: str) -> None: + """ + Logs an info message. + + :param message: The message to log. + """ + services.logger.info(message) + + def warning(message: str) -> None: + """ + Logs a warning message. + + :param message: The message to log. + """ + services.logger.warning(message) + + def error(message: str) -> None: + """ + Logs an error message. + + :param message: The message to log. 
+ """ + services.logger.error(message) + + self.debug = debug + self.info = info + self.warning = warning + self.error = error + + +class ImagesInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def save( + image: Image, + board_id: Optional[str] = None, + image_category: ImageCategory = ImageCategory.GENERAL, + metadata: Optional[MetadataField] = None, + ) -> ImageDTO: + """ + Saves an image, returning its DTO. + + If the current queue item has a workflow, it is automatically saved with the image. + + :param image: The image to save, as a PIL image. + :param board_id: The board ID to add the image to, if it should be added. + :param image_category: The category of the image. Only the GENERAL category is added to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \ + from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \ + override or provide metadata manually. + """ + + # If the invocation inherits metadata, use that. Else, use the metadata passed in. + metadata_ = ( + context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata + ) + + return services.images.create( + image=image, + is_intermediate=context_data.invocation.is_intermediate, + image_category=image_category, + board_id=board_id, + metadata=metadata_, + image_origin=ResourceOrigin.INTERNAL, + workflow=context_data.workflow, + session_id=context_data.session_id, + node_id=context_data.invocation.id, + ) + + def get_pil(image_name: str) -> Image: + """ + Gets an image as a PIL Image object. + + :param image_name: The name of the image to get. + """ + return services.images.get_pil_image(image_name) + + def get_metadata(image_name: str) -> Optional[MetadataField]: + """ + Gets an image's metadata, if it has any. + + :param image_name: The name of the image to get the metadata for. + """ + return services.images.get_metadata(image_name) + + def get_dto(image_name: str) -> ImageDTO: + """ + Gets an image as an ImageDTO object. + + :param image_name: The name of the image to get. + """ + return services.images.get_dto(image_name) + + def update( + image_name: str, + board_id: Optional[str] = None, + is_intermediate: Optional[bool] = False, + ) -> ImageDTO: + """ + Updates an image, returning its updated DTO. + + It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. + + If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to + get the updated image. + + :param image_name: The name of the image to update. + :param board_id: The board ID to add the image to, if it should be added. + :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. 
+ """ + if is_intermediate is not None: + services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) + if board_id is None: + services.board_images.remove_image_from_board(image_name) + else: + services.board_images.add_image_to_board(image_name, board_id) + return services.images.get_dto(image_name) + + self.save = save + self.get_pil = get_pil + self.get_metadata = get_metadata + self.get_dto = get_dto + self.update = update + + +class LatentsKind(str, Enum): + IMAGE = "image" + NOISE = "noise" + MASK = "mask" + MASKED_IMAGE = "masked_image" + OTHER = "other" + + +class LatentsInterface: + def __init__( + self, + services: InvocationServices, + context_data: InvocationContextData, + ) -> None: + def save(tensor: Tensor) -> str: + """ + Saves a latents tensor, returning its name. + + :param tensor: The latents tensor to save. + """ + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" + services.latents.save( + name=name, + data=tensor, + ) + return name + + def get(latents_name: str) -> Tensor: + """ + Gets a latents tensor by name. + + :param latents_name: The name of the latents tensor to get. + """ + return services.latents.get(latents_name) + + self.save = save + self.get = get + + +class ConditioningInterface: + def __init__( + self, + services: InvocationServices, + context_data: InvocationContextData, + ) -> None: + def save(conditioning_data: ConditioningFieldData) -> str: + """ + Saves a conditioning data object, returning its name. + + :param conditioning_data: The conditioning data to save. + """ + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" + services.latents.save( + name=name, + data=conditioning_data, # type: ignore [arg-type] + ) + return name + + def get(conditioning_name: str) -> Tensor: + """ + Gets conditioning data by name. + + :param conditioning_name: The name of the conditioning data to get. + """ + return services.latents.get(conditioning_name) + + self.save = save + self.get = get + + +class ModelsInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: + """ + Checks if a model exists. + + :param model_name: The name of the model to check. + :param base_model: The base model of the model to check. + :param model_type: The type of the model to check. + """ + return services.model_manager.model_exists(model_name, base_model, model_type) + + def load( + model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None + ) -> ModelInfo: + """ + Loads a model, returning its `ModelInfo` object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + :param submodel: The submodel of the model to get. + """ + return services.model_manager.get_model( + model_name, base_model, model_type, submodel, context_data=context_data + ) + + def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Gets a model's info, an dict-like object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. 
+ """ + return services.model_manager.model_info(model_name, base_model, model_type) + + self.exists = exists + self.load = load + self.get_info = get_info + + +class ConfigInterface: + def __init__(self, services: InvocationServices) -> None: + def get() -> InvokeAIAppConfig: + """ + Gets the app's config. + """ + # The config can be changed at runtime. We don't want nodes doing this, so we make a + # frozen copy.. + config = services.configuration.get_config() + frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) + return frozen_config + + self.get = get + + +class UtilInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def sd_step_callback( + intermediate_state: PipelineIntermediateState, + base_model: BaseModelType, + ) -> None: + """ + The step callback emits a progress event with the current step, the total number of + steps, a preview image, and some other internal metadata. + + This should be called after each step of the diffusion process. + + :param intermediate_state: The intermediate state of the diffusion pipeline. + :param base_model: The base model for the current denoising step. + """ + stable_diffusion_step_callback( + context_data=context_data, + intermediate_state=intermediate_state, + base_model=base_model, + invocation_queue=services.queue, + events=services.events, + ) + + self.sd_step_callback = sd_step_callback + + +class InvocationContext: + """ + The invocation context provides access to various services and data about the current invocation. + """ + + def __init__( + self, + images: ImagesInterface, + latents: LatentsInterface, + models: ModelsInterface, + config: ConfigInterface, + logger: LoggerInterface, + data: InvocationContextData, + util: UtilInterface, + conditioning: ConditioningInterface, + ) -> None: + self.images = images + "Provides methods to save, get and update images and their metadata." + self.logger = logger + "Provides access to the app logger." + self.latents = latents + "Provides methods to save and get latents tensors, including image, noise, masks, and masked images." + self.conditioning = conditioning + "Provides methods to save and get conditioning data." + self.models = models + "Provides methods to check if a model exists, get a model, and get a model's info." + self.config = config + "Provides access to the app's config." + self.data = data + "Provides data about the current queue item and invocation." + self.util = util + "Provides utility methods." + + +def build_invocation_context( + services: InvocationServices, + context_data: InvocationContextData, +) -> InvocationContext: + """ + Builds the invocation context. This is a wrapper around the invocation services that provides + a more convenient (and less dangerous) interface for nodes to use. + + :param invocation_services: The invocation services to wrap. + :param invocation_context_data: The invocation context data. 
+ """ + + logger = LoggerInterface(services=services) + images = ImagesInterface(services=services, context_data=context_data) + latents = LatentsInterface(services=services, context_data=context_data) + models = ModelsInterface(services=services, context_data=context_data) + config = ConfigInterface(services=services) + util = UtilInterface(services=services, context_data=context_data) + conditioning = ConditioningInterface(services=services, context_data=context_data) + + ctx = InvocationContext( + images=images, + logger=logger, + config=config, + latents=latents, + models=models, + data=context_data, + util=util, + conditioning=conditioning, + ) + + return ctx diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index f166206d528..5cc3caa9ba5 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,12 +1,25 @@ +from typing import Protocol + import torch from PIL import Image +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage +from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC +from invokeai.app.services.shared.invocation_context import InvocationContextData from ...backend.model_management.models import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL -from ..invocations.baseinvocation import InvocationContext + + +class StepCallback(Protocol): + def __call__( + self, + intermediate_state: PipelineIntermediateState, + base_model: BaseModelType, + ) -> None: + ... def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): @@ -25,13 +38,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix= def stable_diffusion_step_callback( - context: InvocationContext, + context_data: InvocationContextData, intermediate_state: PipelineIntermediateState, - node: dict, - source_node_id: str, base_model: BaseModelType, -): - if context.services.queue.is_canceled(context.graph_execution_state_id): + invocation_queue: InvocationQueueABC, + events: EventServiceBase, +) -> None: + if invocation_queue.is_canceled(context_data.session_id): raise CanceledException # Some schedulers report not only the noisy latents at the current timestep, @@ -108,13 +121,13 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") - context.services.events.emit_generator_progress( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - node=node, - source_node_id=source_node_id, + events.emit_generator_progress( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, + node_id=context_data.invocation.id, + source_node_id=context_data.source_node_id, progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), step=intermediate_state.step, order=intermediate_state.order, From 97a6c6eea799a431a927677e797cb2a5c544ba56 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:01 +1100 Subject: [PATCH 03/67] feat: add pyright config I was having issues with mypy bother over- and under-reporting certain 
problems. I've added a pyright config. --- pyproject.toml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7f4b0d77f25..d063f1ad0ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -280,3 +280,19 @@ module = [ "invokeai.frontend.install.model_install", ] #=== End: MyPy + +[tool.pyright] +include = [ + "invokeai/app/invocations/" +] +exclude = [ + "**/node_modules", + "**/__pycache__", + "invokeai/app/invocations/onnx.py", + "invokeai/app/api/routers/models.py", + "invokeai/app/services/invocation_stats/invocation_stats_default.py", + "invokeai/app/services/model_manager/model_manager_base.py", + "invokeai/app/services/model_manager/model_manager_default.py", + "invokeai/app/services/model_records/model_records_sql.py", + "invokeai/app/util/controlnet_utils.py", +] From 7e5ba2795e0bd9d90e82150899eef8b8edc34f8f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:16 +1100 Subject: [PATCH 04/67] feat(nodes): update all invocations to use new invocation context Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them. Supporting minor changes: - Patch bump for all nodes that use the context - Update invocation processor to provide new context - Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node - Minor change to `ModelManagerService` to support the new wrapped context - Finagling of imports to avoid circular dependencies --- invokeai/app/invocations/baseinvocation.py | 54 +- invokeai/app/invocations/collections.py | 8 +- invokeai/app/invocations/compel.py | 105 ++-- .../controlnet_image_processors.py | 56 +- invokeai/app/invocations/cv.py | 32 +- invokeai/app/invocations/facetools.py | 129 ++-- invokeai/app/invocations/fields.py | 57 +- invokeai/app/invocations/image.py | 586 +++++------------- invokeai/app/invocations/infill.py | 123 +--- invokeai/app/invocations/ip_adapter.py | 9 +- invokeai/app/invocations/latent.py | 238 +++---- invokeai/app/invocations/math.py | 22 +- invokeai/app/invocations/metadata.py | 19 +- invokeai/app/invocations/model.py | 29 +- invokeai/app/invocations/noise.py | 25 +- invokeai/app/invocations/onnx.py | 10 +- invokeai/app/invocations/param_easing.py | 44 +- invokeai/app/invocations/primitives.py | 136 ++-- invokeai/app/invocations/prompt.py | 6 +- invokeai/app/invocations/sdxl.py | 13 +- invokeai/app/invocations/strings.py | 11 +- invokeai/app/invocations/t2i_adapter.py | 6 +- invokeai/app/invocations/tiles.py | 38 +- invokeai/app/invocations/upscale.py | 33 +- invokeai/app/services/events/events_base.py | 4 +- .../invocation_processor_default.py | 24 +- .../model_manager/model_manager_base.py | 9 +- .../model_manager/model_manager_common.py | 0 .../model_manager/model_manager_default.py | 44 +- invokeai/app/services/shared/graph.py | 7 +- .../app/services/shared/invocation_context.py | 9 +- invokeai/app/util/step_callback.py | 23 +- 32 files changed, 717 insertions(+), 1192 deletions(-) create mode 100644 invokeai/app/services/model_manager/model_manager_common.py diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 395d5e98707..c4aed1fac5a 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -16,10 +16,16 @@ from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined -from invokeai.app.invocations.fields 
import FieldKind, Input +from invokeai.app.invocations.fields import ( + FieldDescriptions, + FieldKind, + Input, + InputFieldJSONSchemaExtra, + MetadataField, + logger, +) from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string from invokeai.backend.util.logging import InvokeAILogger @@ -219,7 +225,7 @@ def invoke(self, context: InvocationContext) -> BaseInvocationOutput: """Invoke with provided context and return outputs.""" pass - def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: + def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput: """ Internal invoke method, calls `invoke()` after some prep. Handles optional fields that are required to call `invoke()` and invocation cache. @@ -244,23 +250,23 @@ def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: raise MissingInputException(self.model_fields["type"].default, field_name) # skip node cache codepath if it's disabled - if context.services.configuration.node_cache_size == 0: + if services.configuration.node_cache_size == 0: return self.invoke(context) output: BaseInvocationOutput if self.use_cache: - key = context.services.invocation_cache.create_key(self) - cached_value = context.services.invocation_cache.get(key) + key = services.invocation_cache.create_key(self) + cached_value = services.invocation_cache.get(key) if cached_value is None: - context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}') + services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}') output = self.invoke(context) - context.services.invocation_cache.save(key, output) + services.invocation_cache.save(key, output) return output else: - context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}') + services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}') return cached_value else: - context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') + services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') return self.invoke(context) id: str = Field( @@ -513,3 +519,29 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: return cls return wrapper + + +class WithMetadata(BaseModel): + """ + Inherit from this class if your node needs a metadata input field. + """ + + metadata: Optional[MetadataField] = Field( + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + +class WithWorkflow: + workflow = None + + def __init_subclass__(cls) -> None: + logger.warn( + f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." 
+ ) + super().__init_subclass__() diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index d35a9d79c74..f5709b4ba36 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import IntegerCollectionOutput from invokeai.app.util.misc import SEED_MAX -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -27,7 +27,7 @@ def stop_gt_start(cls, v: int, info: ValidationInfo): raise ValueError("stop must be greater than start") return v - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step))) @@ -45,7 +45,7 @@ class RangeOfSizeInvocation(BaseInvocation): size: int = InputField(default=1, gt=0, description="The number of values") step: int = InputField(default=1, description="The step of the range") - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput( collection=list(range(self.start, self.start + (self.step * self.size), self.step)) ) @@ -72,6 +72,6 @@ class RandomRangeInvocation(BaseInvocation): description="The seed for the RNG (omit for random)", ) - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: rng = np.random.default_rng(self.seed) return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index b386aef2cbe..b4496031bc4 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,12 +1,18 @@ -from dataclasses import dataclass -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput +from invokeai.app.invocations.fields import ( + ConditioningFieldData, + FieldDescriptions, + Input, + InputField, + OutputField, + UIComponent, +) +from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ExtraConditioningInfo, @@ -20,16 +26,14 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from .model import ClipField +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext -@dataclass -class ConditioningFieldData: - conditionings: List[BasicConditioningInfo] # unconditioned: Optional[torch.Tensor] @@ -44,7 +48,7 @@ class ConditioningFieldData: title="Prompt", tags=["prompt", "compel"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -61,26 +65,18 @@ class CompelInvocation(BaseInvocation): ) 
@torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - context=context, - ) + def invoke(self, context) -> ConditioningOutput: + tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) + text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) def _lora_loader(): for lora in self.clip.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in extract_ti_triggers_from_prompt(self.prompt): @@ -89,11 +85,10 @@ def _lora_loader(): ti_list.append( ( name, - context.services.model_manager.get_model( + context.models.load( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, - context=context, ).context.model, ) ) @@ -124,7 +119,7 @@ def _lora_loader(): conjunction = Compel.parse_prompt_string(self.prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: log_tokenization_for_conjunction(conjunction, tokenizer) c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) @@ -145,34 +140,23 @@ def _lora_loader(): ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) class SDXLPromptInvocationBase: def run_clip_compel( self, - context: InvocationContext, + context: "InvocationContext", clip_field: ClipField, prompt: str, get_pooled: bool, lora_prefix: str, zero_on_empty: bool, ): - tokenizer_info = context.services.model_manager.get_model( - **clip_field.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **clip_field.text_encoder.model_dump(), - context=context, - ) + tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) + text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) # return zero on empty if prompt == "" and zero_on_empty: @@ -196,14 +180,12 @@ def run_clip_compel( def _lora_loader(): for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in 
extract_ti_triggers_from_prompt(prompt): @@ -212,11 +194,10 @@ def _lora_loader(): ti_list.append( ( name, - context.services.model_manager.get_model( + context.models.load( model_name=name, base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, - context=context, ).context.model, ) ) @@ -249,7 +230,7 @@ def _lora_loader(): conjunction = Compel.parse_prompt_string(prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: # TODO: better logging for and syntax log_tokenization_for_conjunction(conjunction, tokenizer) @@ -282,7 +263,7 @@ def _lora_loader(): title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -307,7 +288,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True ) @@ -364,14 +345,9 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) @invocation( @@ -379,7 +355,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: title="SDXL Refiner Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -397,7 +373,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: # TODO: if there will appear lora for refiner - write proper prefix c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) @@ -417,14 +393,9 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) @invocation_output("clip_skip_output") @@ -447,7 +418,7 @@ class ClipSkipInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) - def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: + def 
invoke(self, context) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers return ClipSkipInvocationOutput( clip=self.clip, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 9b652b8eee9..3797722c93e 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,18 +25,17 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput +from invokeai.app.invocations.baseinvocation import WithMetadata +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector +from invokeai.backend.model_management.models.base import BaseModelType -from ...backend.model_management import BaseModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -121,7 +120,7 @@ def validate_begin_end_step_percent(self) -> "ControlNetInvocation": validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> ControlOutput: + def invoke(self, context) -> ControlOutput: return ControlOutput( control=ControlField( image=self.image, @@ -145,23 +144,14 @@ def run_processor(self, image: Image.Image) -> Image.Image: # superclass just passes through image without processing return image - def invoke(self, context: InvocationContext) -> ImageOutput: - raw_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + raw_image = context.images.get_pil(self.image.image_name) # image type should be PIL.PngImagePlugin.PngImageFile ? 
processed_image = self.run_processor(raw_image) # currently can't see processed image in node UI without a showImage node, # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery - image_dto = context.services.images.create( - image=processed_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.CONTROL, - session_id=context.graph_execution_state_id, - node_id=self.id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=processed_image) """Builds an ImageOutput and its ImageField""" processed_image_field = ImageField(image_name=image_dto.image_name) @@ -180,7 +170,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Canny Processor", tags=["controlnet", "canny"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class CannyImageProcessorInvocation(ImageProcessorInvocation): """Canny edge detection for ControlNet""" @@ -203,7 +193,7 @@ def run_processor(self, image): title="HED (softedge) Processor", tags=["controlnet", "hed", "softedge"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class HedImageProcessorInvocation(ImageProcessorInvocation): """Applies HED edge detection to image""" @@ -232,7 +222,7 @@ def run_processor(self, image): title="Lineart Processor", tags=["controlnet", "lineart"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartImageProcessorInvocation(ImageProcessorInvocation): """Applies line art processing to image""" @@ -254,7 +244,7 @@ def run_processor(self, image): title="Lineart Anime Processor", tags=["controlnet", "lineart", "anime"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): """Applies line art anime processing to image""" @@ -277,7 +267,7 @@ def run_processor(self, image): title="Midas Depth Processor", tags=["controlnet", "midas"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Midas depth processing to image""" @@ -304,7 +294,7 @@ def run_processor(self, image): title="Normal BAE Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): """Applies NormalBae processing to image""" @@ -321,7 +311,7 @@ def run_processor(self, image): @invocation( - "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0" + "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1" ) class MlsdImageProcessorInvocation(ImageProcessorInvocation): """Applies MLSD processing to image""" @@ -344,7 +334,7 @@ def run_processor(self, image): @invocation( - "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0" + "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1" ) class PidiImageProcessorInvocation(ImageProcessorInvocation): """Applies PIDI processing to image""" @@ -371,7 +361,7 @@ def run_processor(self, image): title="Content Shuffle Processor", tags=["controlnet", "contentshuffle"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): """Applies content shuffle 
processing to image""" @@ -401,7 +391,7 @@ def run_processor(self, image): title="Zoe (Depth) Processor", tags=["controlnet", "zoe", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" @@ -417,7 +407,7 @@ def run_processor(self, image): title="Mediapipe Face Processor", tags=["controlnet", "mediapipe", "face"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): """Applies mediapipe face processing to image""" @@ -440,7 +430,7 @@ def run_processor(self, image): title="Leres (Depth) Processor", tags=["controlnet", "leres", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LeresImageProcessorInvocation(ImageProcessorInvocation): """Applies leres processing to image""" @@ -469,7 +459,7 @@ def run_processor(self, image): title="Tile Resample Processor", tags=["controlnet", "tile"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class TileResamplerProcessorInvocation(ImageProcessorInvocation): """Tile resampler processor""" @@ -509,7 +499,7 @@ def run_processor(self, img): title="Segment Anything Processor", tags=["controlnet", "segmentanything"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): """Applies segment anything processing to image""" @@ -551,7 +541,7 @@ def show_anns(self, anns: List[Dict]): title="Color Map Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ColorMapImageProcessorInvocation(ImageProcessorInvocation): """Generates a color map from the provided image""" diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 5865338e192..375b18f9c58 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -5,23 +5,23 @@ import numpy from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata -@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0") +@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1") class CvInpaintInvocation(BaseInvocation, WithMetadata): """Simple inpaint using opencv.""" image: ImageField = InputField(description="The image to inpaint") mask: ImageField = InputField(description="The mask to use when inpainting") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - mask = context.services.images.get_pil_image(self.mask.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + mask = context.images.get_pil(self.mask.image_name) # Convert to cv image/mask # TODO: consider making these utility functions @@ -35,18 +35,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # TODO: consider making a utility function image_inpainted = 
Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) - image_dto = context.services.images.create( - image=image_inpainted, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image_dto = context.images.save(image=image_inpainted) + + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 13f1066ec3e..2c92e28cfe0 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -1,7 +1,7 @@ import math import re from pathlib import Path -from typing import Optional, TypedDict +from typing import TYPE_CHECKING, Optional, TypedDict import cv2 import numpy as np @@ -13,13 +13,16 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - InvocationContext, + WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField, InputField, OutputField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory + +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext @invocation_output("face_mask_output") @@ -174,7 +177,7 @@ def prepare_faces_list( def generate_face_box_mask( - context: InvocationContext, + context: "InvocationContext", minimum_confidence: float, x_offset: float, y_offset: float, @@ -273,7 +276,7 @@ def generate_face_box_mask( def extract_face( - context: InvocationContext, + context: "InvocationContext", image: ImageType, face: FaceResultData, padding: int, @@ -304,37 +307,37 @@ def extract_face( # Adjust the crop boundaries to stay within the original image's dimensions if x_min < 0: - context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.") + context.logger.warning("FaceTools --> -X-axis padding reached image edge.") x_max -= x_min x_min = 0 elif x_max > mask.width: - context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.") + context.logger.warning("FaceTools --> +X-axis padding reached image edge.") x_min -= x_max - mask.width x_max = mask.width if y_min < 0: - context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> +Y-axis padding reached image edge.") y_max -= y_min y_min = 0 elif y_max > mask.height: - context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> -Y-axis padding reached image edge.") y_min -= y_max - mask.height y_max = mask.height # Ensure the crop is square and adjust the boundaries if needed if x_max - x_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") diff = crop_size - (x_max - x_min) x_min -= diff 
// 2 x_max += diff - diff // 2 if y_max - y_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") diff = crop_size - (y_max - y_min) y_min -= diff // 2 y_max += diff - diff // 2 - context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") + context.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") # Crop the output image to the specified size with the center of the face mesh as the center. mask = mask.crop((x_min, y_min, x_max, y_max)) @@ -354,7 +357,7 @@ def extract_face( def get_faces_list( - context: InvocationContext, + context: "InvocationContext", image: ImageType, should_chunk: bool, minimum_confidence: float, @@ -366,7 +369,7 @@ def get_faces_list( # Generate the face box mask and get the center of the face. if not should_chunk: - context.services.logger.info("FaceTools --> Attempting full image face detection.") + context.logger.info("FaceTools --> Attempting full image face detection.") result = generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -378,7 +381,7 @@ def get_faces_list( draw_mesh=draw_mesh, ) if should_chunk or len(result) == 0: - context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") + context.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") width, height = image.size image_chunks = [] x_offsets = [] @@ -397,7 +400,7 @@ def get_faces_list( x_offsets.append(x) y_offsets.append(0) fx += increment - context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}") + context.logger.info(f"FaceTools --> Chunk starting at x = {x}") elif height > width: # Portrait - slice the image vertically fy = 0.0 @@ -409,10 +412,10 @@ def get_faces_list( x_offsets.append(0) y_offsets.append(y) fy += increment - context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}") + context.logger.info(f"FaceTools --> Chunk starting at y = {y}") for idx in range(len(image_chunks)): - context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") + context.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") result = result + generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -426,7 +429,7 @@ def get_faces_list( if len(result) == 0: # Give up - context.services.logger.warning( + context.logger.warning( "FaceTools --> No face detected in chunked input image. Passing through original image." ) @@ -435,7 +438,7 @@ def get_faces_list( return all_faces -@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0") +@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1") class FaceOffInvocation(BaseInvocation, WithMetadata): """Bound, extract, and mask a face from an image using MediaPipe detection""" @@ -456,7 +459,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]: + def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]: all_faces = get_faces_list( context=context, image=image, @@ -468,11 +471,11 @@ def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[Extr ) if len(all_faces) == 0: - context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.") + context.logger.warning("FaceOff --> No faces detected. Passing through original image.") return None if self.face_id > len(all_faces) - 1: - context.services.logger.warning( + context.logger.warning( f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image." ) return None @@ -483,8 +486,8 @@ def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[Extr return face_data - def invoke(self, context: InvocationContext) -> FaceOffOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> FaceOffOutput: + image = context.images.get_pil(self.image.image_name) result = self.faceoff(context=context, image=image) if result is None: @@ -498,24 +501,9 @@ def invoke(self, context: InvocationContext) -> FaceOffOutput: x = result["x_min"] y = result["y_min"] - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - mask_dto = context.services.images.create( - image=result_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result_mask, image_category=ImageCategory.MASK) output = FaceOffOutput( image=ImageField(image_name=image_dto.image_name), @@ -529,7 +517,7 @@ def invoke(self, context: InvocationContext) -> FaceOffOutput: return output -@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0") +@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1") class FaceMaskInvocation(BaseInvocation, WithMetadata): """Face mask creation using mediapipe face detection""" @@ -556,7 +544,7 @@ def validate_comma_separated_ints(cls, v) -> str: raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")') return v - def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult: + def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult: all_faces = get_faces_list( context=context, image=image, @@ -578,7 +566,7 @@ def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResu if len(intersected_face_ids) == 0: id_range_str = ",".join([str(id) for id in id_range]) - context.services.logger.warning( + context.logger.warning( f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image." 
) return FaceMaskResult( @@ -613,28 +601,13 @@ def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResu mask=mask_pil, ) - def invoke(self, context: InvocationContext) -> FaceMaskOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> FaceMaskOutput: + image = context.images.get_pil(self.image.image_name) result = self.facemask(context=context, image=image) - image_dto = context.services.images.create( - image=result["image"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result["image"]) - mask_dto = context.services.images.create( - image=result["mask"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result["mask"], image_category=ImageCategory.MASK) output = FaceMaskOutput( image=ImageField(image_name=image_dto.image_name), @@ -647,7 +620,7 @@ def invoke(self, context: InvocationContext) -> FaceMaskOutput: @invocation( - "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0" + "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1" ) class FaceIdentifierInvocation(BaseInvocation, WithMetadata): """Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" @@ -661,7 +634,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType: + def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType: image = image.copy() all_faces = get_faces_list( @@ -702,22 +675,10 @@ def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageT return image - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) result_image = self.faceidentifier(context=context, image=image) - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 0cce8e3c6b5..566babbb6b7 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -1,11 +1,13 @@ +from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional, Tuple from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter from pydantic.fields import _Unset from pydantic_core import PydanticUndefined from invokeai.app.util.metaenum import MetaEnum +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -255,6 +257,10 @@ class InputFieldJSONSchemaExtra(BaseModel): class WithMetadata(BaseModel): + """ + Inherit from this class if your node needs a metadata input field. 
+ """ + metadata: Optional[MetadataField] = Field( default=None, description=FieldDescriptions.metadata, @@ -498,4 +504,53 @@ def OutputField( field_kind=FieldKind.Output, ).model_dump(exclude_none=True), ) + + +class ImageField(BaseModel): + """An image primitive field""" + + image_name: str = Field(description="The name of the image") + + +class BoardField(BaseModel): + """A board primitive field""" + + board_id: str = Field(description="The id of the board") + + +class DenoiseMaskField(BaseModel): + """An inpaint mask field""" + + mask_name: str = Field(description="The name of the mask image") + masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") + + +class LatentsField(BaseModel): + """A latents tensor primitive field""" + + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") + + +class ColorField(BaseModel): + """A color primitive field""" + + r: int = Field(ge=0, le=255, description="The red component") + g: int = Field(ge=0, le=255, description="The green component") + b: int = Field(ge=0, le=255, description="The blue component") + a: int = Field(ge=0, le=255, description="The alpha component") + + def tuple(self) -> Tuple[int, int, int, int]: + return (self.r, self.g, self.b, self.a) + + +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] + + +class ConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") # endregion diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 16d0f33dda3..10ebd97ace3 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,30 +7,36 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, WithMetadata -from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.invocations.baseinvocation import WithMetadata +from invokeai.app.invocations.fields import ( + BoardField, + ColorField, + FieldDescriptions, + ImageField, + Input, + InputField, +) +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker from .baseinvocation import ( BaseInvocation, Classification, - InvocationContext, invocation, ) -@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0") +@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1") class ShowImageInvocation(BaseInvocation): """Displays a provided image using the OS image viewer, and passes it forward in the pipeline.""" image: ImageField = InputField(description="The image to show") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - if image: - image.show() + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + image.show() # TODO: how to handle failure? 
@@ -46,7 +52,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blank Image", tags=["image"], category="image", - version="1.2.0", + version="1.2.1", ) class BlankImageInvocation(BaseInvocation, WithMetadata): """Creates a blank image and forwards it to the pipeline""" @@ -56,25 +62,12 @@ class BlankImageInvocation(BaseInvocation, WithMetadata): mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image") color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple()) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -82,7 +75,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Crop Image", tags=["image", "crop"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageCropInvocation(BaseInvocation, WithMetadata): """Crops an image to a specified box. The box can be outside of the image.""" @@ -93,28 +86,15 @@ class ImageCropInvocation(BaseInvocation, WithMetadata): width: int = InputField(default=512, gt=0, description="The width of the crop rectangle") height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)) image_crop.paste(image, (-self.x, -self.y)) - image_dto = context.services.images.create( - image=image_crop, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -145,8 +125,8 @@ class CenterPadCropInvocation(BaseInvocation): description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) # Calculate and create new image dimensions new_width = image.width + self.right + self.left @@ -156,20 +136,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Paste new image onto input image_crop.paste(image, (self.left, self.top)) - image_dto = context.services.images.create( - image=image_crop, - 
image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -177,7 +146,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Paste Image", tags=["image", "paste"], category="image", - version="1.2.0", + version="1.2.1", ) class ImagePasteInvocation(BaseInvocation, WithMetadata): """Pastes an image into another image.""" @@ -192,12 +161,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): y: int = InputField(default=0, description="The top y coordinate at which to paste the image") crop: bool = InputField(default=False, description="Crop to base image dimensions") - def invoke(self, context: InvocationContext) -> ImageOutput: - base_image = context.services.images.get_pil_image(self.base_image.image_name) - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + base_image = context.images.get_pil(self.base_image.image_name) + image = context.images.get_pil(self.image.image_name) mask = None if self.mask is not None: - mask = context.services.images.get_pil_image(self.mask.image_name) + mask = context.images.get_pil(self.mask.image_name) mask = ImageOps.invert(mask.convert("L")) # TODO: probably shouldn't invert mask here... should user be required to do it? @@ -214,22 +183,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: base_w, base_h = base_image.size new_image = new_image.crop((abs(min_x), abs(min_y), abs(min_x) + base_w, abs(min_y) + base_h)) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -237,7 +193,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Mask from Alpha", tags=["image", "mask"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): """Extracts the alpha channel of an image as a mask.""" @@ -245,29 +201,16 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to create the mask from") invert: bool = InputField(default=False, description="Whether or not to invert the mask") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_mask = image.split()[-1] if self.invert: image_mask = ImageOps.invert(image_mask) - image_dto = context.services.images.create( - image=image_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - 
workflow=context.workflow, - ) + image_dto = context.images.save(image=image_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -275,7 +218,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Multiply Images", tags=["image", "multiply"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageMultiplyInvocation(BaseInvocation, WithMetadata): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" @@ -283,28 +226,15 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata): image1: ImageField = InputField(description="The first image to multiply") image2: ImageField = InputField(description="The second image to multiply") - def invoke(self, context: InvocationContext) -> ImageOutput: - image1 = context.services.images.get_pil_image(self.image1.image_name) - image2 = context.services.images.get_pil_image(self.image2.image_name) + def invoke(self, context) -> ImageOutput: + image1 = context.images.get_pil(self.image1.image_name) + image2 = context.images.get_pil(self.image2.image_name) multiply_image = ImageChops.multiply(image1, image2) - image_dto = context.services.images.create( - image=multiply_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=multiply_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_CHANNELS = Literal["A", "R", "G", "B"] @@ -315,7 +245,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Extract Image Channel", tags=["image", "channel"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelInvocation(BaseInvocation, WithMetadata): """Gets a channel from an image.""" @@ -323,27 +253,14 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to get the channel from") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) channel_image = image.getchannel(self.channel) - image_dto = context.services.images.create( - image=channel_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=channel_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] @@ -354,7 +271,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Convert Image Mode", tags=["image", "convert"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageConvertInvocation(BaseInvocation, 
WithMetadata): """Converts an image to a different mode.""" @@ -362,27 +279,14 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to convert") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) converted_image = image.convert(self.mode) - image_dto = context.services.images.create( - image=converted_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=converted_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -390,7 +294,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blur Image", tags=["image", "blur"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageBlurInvocation(BaseInvocation, WithMetadata): """Blurs an image""" @@ -400,30 +304,17 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata): # Metadata blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) blur = ( ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius) ) blur_image = image.filter(blur) - image_dto = context.services.images.create( - image=blur_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=blur_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -431,7 +322,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Unsharp Mask", tags=["image", "unsharp_mask"], category="image", - version="1.2.0", + version="1.2.1", classification=Classification.Beta, ) class UnsharpMaskInvocation(BaseInvocation, WithMetadata): @@ -447,8 +338,8 @@ def pil_from_array(self, arr): def array_from_pil(self, img): return numpy.array(img) / 255 - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) mode = image.mode alpha_channel = image.getchannel("A") if mode == "RGBA" else None @@ -466,16 +357,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if alpha_channel is not None: image.putalpha(alpha_channel) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - 
session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) return ImageOutput( image=ImageField(image_name=image_dto.image_name), @@ -509,7 +391,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Resize Image", tags=["image", "resize"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageResizeInvocation(BaseInvocation, WithMetadata): """Resizes an image to specific dimensions""" @@ -519,8 +401,8 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): height: int = InputField(default=512, gt=0, description="The height to resize to (px)") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -529,22 +411,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -552,7 +421,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Scale Image", tags=["image", "scale"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageScaleInvocation(BaseInvocation, WithMetadata): """Scales an image by a factor""" @@ -565,8 +434,8 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): ) resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width * self.scale_factor) @@ -577,22 +446,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -600,7 +456,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Lerp Image", tags=["image", "lerp"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageLerpInvocation(BaseInvocation, WithMetadata): """Linear interpolation of all pixels of an image""" @@ -609,30 +465,17 @@ class 
ImageLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum output value") max: int = InputField(default=255, ge=0, le=255, description="The maximum output value") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = image_arr * (self.max - self.min) + self.min lerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=lerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=lerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -640,7 +483,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): """Inverse linear interpolation of all pixels of an image""" @@ -649,30 +492,17 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 # type: ignore [assignment] ilerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=ilerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=ilerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -680,17 +510,17 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blur NSFW Image", tags=["image", "nsfw"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): """Add blur to NSFW-flagged images""" image: ImageField = InputField(description="The image to check") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) - logger = context.services.logger + logger = context.logger logger.debug("Running NSFW checker") if SafetyChecker.has_nsfw_concept(image): logger.info("A potentially NSFW image has been 
detected. Image will be blurred.") @@ -699,22 +529,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: blurry_image.paste(caution, (0, 0), caution) image = blurry_image - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) def _get_caution_img(self) -> Image.Image: import invokeai.app.assets.images as image_assets @@ -728,7 +545,7 @@ def _get_caution_img(self) -> Image.Image: title="Add Invisible Watermark", tags=["image", "watermark"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageWatermarkInvocation(BaseInvocation, WithMetadata): """Add an invisible watermark to an image""" @@ -736,25 +553,12 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") text: str = InputField(default="InvokeAI", description="Watermark text") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) new_image = InvisibleWatermark.add_watermark(image, self.text) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -762,7 +566,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskEdgeInvocation(BaseInvocation, WithMetadata): """Applies an edge mask to an image""" @@ -775,8 +579,8 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): description="Second threshold for the hysteresis procedure in Canny edge detection" ) - def invoke(self, context: InvocationContext) -> ImageOutput: - mask = context.services.images.get_pil_image(self.image.image_name).convert("L") + def invoke(self, context) -> ImageOutput: + mask = context.images.get_pil(self.image.image_name).convert("L") npimg = numpy.asarray(mask, dtype=numpy.uint8) npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0))) @@ -791,22 +595,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: new_mask = ImageOps.invert(new_mask) - image_dto = context.services.images.create( - image=new_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_mask, image_category=ImageCategory.MASK) - return ImageOutput( - 
image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -814,7 +605,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Combine Masks", tags=["image", "mask", "multiply"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskCombineInvocation(BaseInvocation, WithMetadata): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" @@ -822,28 +613,15 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata): mask1: ImageField = InputField(description="The first mask to combine") mask2: ImageField = InputField(description="The second image to combine") - def invoke(self, context: InvocationContext) -> ImageOutput: - mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L") - mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L") + def invoke(self, context) -> ImageOutput: + mask1 = context.images.get_pil(self.mask1.image_name).convert("L") + mask2 = context.images.get_pil(self.mask2.image_name).convert("L") combined_mask = ImageChops.multiply(mask1, mask2) - image_dto = context.services.images.create( - image=combined_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=combined_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -851,7 +629,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Color Correct", tags=["image", "color"], category="image", - version="1.2.0", + version="1.2.1", ) class ColorCorrectInvocation(BaseInvocation, WithMetadata): """ @@ -864,14 +642,14 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") mask_blur_radius: float = InputField(default=8, description="Mask blur radius") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: pil_init_mask = None if self.mask is not None: - pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L") + pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L") - init_image = context.services.images.get_pil_image(self.reference.image_name) + init_image = context.images.get_pil(self.reference.image_name) - result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + result = context.images.get_pil(self.image.image_name).convert("RGBA") # if init_image is None or init_mask is None: # return result @@ -945,22 +723,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Paste original on color-corrected generation (using blurred mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask) - image_dto = context.services.images.create( - image=matched_result, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = 
context.images.save(image=matched_result) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -968,7 +733,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Adjust Image Hue", tags=["image", "hue"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): """Adjusts the Hue of an image.""" @@ -976,8 +741,8 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to adjust") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # Convert image to HSV color space hsv_image = numpy.array(pil_image.convert("HSV")) @@ -991,24 +756,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to PIL format and to original color mode pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) COLOR_CHANNELS = Literal[ @@ -1072,7 +822,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: "value", ], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): """Add or subtract a value from a specific color channel of an image.""" @@ -1081,8 +831,8 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): channel: COLOR_CHANNELS = InputField(description="Which channel to adjust") offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1101,24 +851,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1143,7 +878,7 @@ def 
invoke(self, context: InvocationContext) -> ImageOutput: "value", ], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): """Scale a specific color channel of an image.""" @@ -1153,8 +888,8 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.") invert_channel: bool = InputField(default=False, description="Invert the channel after scaling") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1177,24 +912,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - workflow=context.workflow, - metadata=self.metadata, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1202,7 +922,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Save Image", tags=["primitives", "image"], category="primitives", - version="1.2.0", + version="1.2.1", use_cache=False, ) class SaveImageInvocation(BaseInvocation, WithMetadata): @@ -1211,26 +931,12 @@ class SaveImageInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - board_id=self.board.board_id if self.board else None, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1238,7 +944,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Linear UI Image Output", tags=["primitives", "image"], category="primitives", - version="1.0.1", + version="1.0.2", use_cache=False, ) class LinearUIOutputInvocation(BaseInvocation, WithMetadata): @@ -1247,19 +953,13 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: Optional[BoardField] = InputField(default=None, 
description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context: InvocationContext) -> ImageOutput: - image_dto = context.services.images.get_dto(self.image.image_name) - - if self.board: - context.services.board_images.add_image_to_board(self.board.board_id, self.image.image_name) + def invoke(self, context) -> ImageOutput: + image_dto = context.images.get_dto(self.image.image_name) - if image_dto.is_intermediate != self.is_intermediate: - context.services.images.update( - self.image.image_name, changes=ImageRecordChanges(is_intermediate=self.is_intermediate) - ) - - return ImageOutput( - image=ImageField(image_name=self.image.image_name), - width=image_dto.width, - height=image_dto.height, + image_dto = context.images.update( + image_name=self.image.image_name, + board_id=self.board.board_id if self.board else None, + is_intermediate=self.is_intermediate, ) + + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index d4d3d5bea44..be51c8312f9 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -6,15 +6,15 @@ import numpy as np from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ColorField, ImageField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, InvocationContext, invocation -from .fields import InputField, WithMetadata +from .baseinvocation import BaseInvocation, WithMetadata, invocation +from .fields import InputField from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES @@ -119,7 +119,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] return si -@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class InfillColorInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image with a solid color""" @@ -129,33 +129,20 @@ class InfillColorInvocation(BaseInvocation, WithMetadata): description="The color to use to infill", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return 
ImageOutput.build(image_dto) -@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") +@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") class InfillTileInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image with tiles of the image""" @@ -168,32 +155,19 @@ class InfillTileInvocation(BaseInvocation, WithMetadata): description="The seed to use for tile generation (omit for random)", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( - "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0" + "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1" ) class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using the PatchMatch algorithm""" @@ -202,8 +176,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name).convert("RGBA") resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -228,77 +202,38 @@ def invoke(self, context: InvocationContext) -> ImageOutput: infilled.paste(image, (0, 0), mask=image.split()[-1]) # image.paste(infilled, (0, 0), mask=image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class LaMaInfillInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using the LaMa model""" image: ImageField = InputField(description="The 
image to infill") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = infill_lama(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class CV2InfillInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using OpenCV Inpainting""" image: ImageField = InputField(description="The image to infill") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = infill_cv2(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index c01e0ed0fb2..b836be04b58 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -7,7 +7,6 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -62,7 +61,7 @@ class IPAdapterOutput(BaseInvocationOutput): ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter") -@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1") +@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2") class IPAdapterInvocation(BaseInvocation): """Collects IP-Adapter info to pass to other nodes.""" @@ -93,9 +92,9 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> IPAdapterOutput: + def invoke(self, context) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
- ip_adapter_info = context.services.model_manager.model_info( + ip_adapter_info = context.models.get_info( self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter ) # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model @@ -104,7 +103,7 @@ def invoke(self, context: InvocationContext) -> IPAdapterOutput: # is currently messy due to differences between how the model info is generated when installing a model from # disk vs. downloading the model. image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"]) + os.path.join(context.config.get().models_path, ip_adapter_info["path"]) ) image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model = CLIPVisionModelField( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 909c307481e..0127a6521e1 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,7 +3,7 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import List, Literal, Optional, Union +from typing import TYPE_CHECKING, List, Literal, Optional, Union import einops import numpy as np @@ -23,21 +23,26 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, WithMetadata +from invokeai.app.invocations.fields import ( + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + LatentsField, + OutputField, + UIType, + WithMetadata, +) from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.primitives import ( - DenoiseMaskField, DenoiseMaskOutput, - ImageField, ImageOutput, - LatentsField, LatentsOutput, - build_latents_output, ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.util.controlnet_utils import prepare_control_image -from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo @@ -59,14 +64,15 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) -from .compel import ConditioningField from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext + if choose_torch_device() == torch.device("mps"): from torch import mps @@ -102,7 +108,7 @@ class SchedulerInvocation(BaseInvocation): ui_type=UIType.Scheduler, ) - def invoke(self, context: InvocationContext) -> SchedulerOutput: + def invoke(self, context) -> SchedulerOutput: return SchedulerOutput(scheduler=self.scheduler) @@ -111,7 +117,7 @@ def invoke(self, context: InvocationContext) -> SchedulerOutput: title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", - version="1.0.0", + version="1.0.1", ) class 
CreateDenoiseMaskInvocation(BaseInvocation): """Creates mask for denoising model run.""" @@ -137,9 +143,9 @@ def prep_mask_tensor(self, mask_image): return mask_tensor @torch.no_grad() - def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: + def invoke(self, context) -> DenoiseMaskOutput: if self.image is not None: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image = image_resized_to_grid_as_tensor(image.convert("RGB")) if image.dim() == 3: image = image.unsqueeze(0) @@ -147,47 +153,37 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: image = None mask = self.prep_mask_tensor( - context.services.images.get_pil_image(self.mask.image_name), + context.images.get_pil(self.mask.image_name), ) if image is not None: - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents" - context.services.latents.save(masked_latents_name, masked_latents) + masked_latents_name = context.latents.save(tensor=masked_latents) else: masked_latents_name = None - mask_name = f"{context.graph_execution_state_id}__{self.id}_mask" - context.services.latents.save(mask_name, mask) + mask_name = context.latents.save(tensor=mask) - return DenoiseMaskOutput( - denoise_mask=DenoiseMaskField( - mask_name=mask_name, - masked_latents_name=masked_latents_name, - ), + return DenoiseMaskOutput.build( + mask_name=mask_name, + masked_latents_name=masked_latents_name, ) def get_scheduler( - context: InvocationContext, + context: "InvocationContext", scheduler_info: ModelInfo, scheduler_name: str, seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.model_dump(), - context=context, - ) + orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -216,7 +212,7 @@ def get_scheduler( title="Denoise Latents", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], category="latents", - version="1.5.1", + version="1.5.2", ) class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" @@ -302,34 +298,18 @@ def ge_one(cls, v): raise ValueError("cfg_scale must be greater than 1") return v - # TODO: pass this an emitter method or something? or a session for dispatching? 
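# A minimal sketch of what replaces the dispatch_progress method removed below, following
# the DenoiseLatentsInvocation.invoke hunk later in this file: the context now resolves the
# source node and emits the progress event itself, so the node-side callback reduces to
# (sketch only; names taken from the hunks below):
#
#   def step_callback(state: PipelineIntermediateState) -> None:
#       context.util.sd_step_callback(state, self.unet.unet.base_model)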
- def dispatch_progress( - self, - context: InvocationContext, - source_node_id: str, - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - base_model=base_model, - ) - def get_conditioning_data( self, - context: InvocationContext, + context: "InvocationContext", scheduler, unet, seed, ) -> ConditioningData: - positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) + positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = c.extra_conditioning - negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) + negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( @@ -389,7 +369,7 @@ def __init__(self): def prep_control_data( self, - context: InvocationContext, + context: "InvocationContext", control_input: Union[ControlField, List[ControlField]], latents_shape: List[int], exit_stack: ExitStack, @@ -417,17 +397,16 @@ def prep_control_data( controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.services.model_manager.get_model( + context.models.load( model_name=control_info.control_model.model_name, model_type=ModelType.ControlNet, base_model=control_info.control_model.base_model, - context=context, ) ) # control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.images.get_pil(control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -463,7 +442,7 @@ def prep_control_data( def prep_ip_adapter_data( self, - context: InvocationContext, + context: "InvocationContext", ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], conditioning_data: ConditioningData, exit_stack: ExitStack, @@ -485,19 +464,17 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.get_model( + context.models.load( model_name=single_ip_adapter.ip_adapter_model.model_name, model_type=ModelType.IPAdapter, base_model=single_ip_adapter.ip_adapter_model.base_model, - context=context, ) ) - image_encoder_model_info = context.services.model_manager.get_model( + image_encoder_model_info = context.models.load( model_name=single_ip_adapter.image_encoder_model.model_name, model_type=ModelType.CLIPVision, base_model=single_ip_adapter.image_encoder_model.base_model, - context=context, ) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. 
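# A minimal sketch of the ExitStack pattern used above, assuming context.models.load
# returns a context manager whose model stays resident until the stack unwinds, so every
# adapter and encoder loaded in this loop remains available for the whole denoise run
# (kwargs abbreviated with placeholder names; see the hunk above for the full call):
#
#   with ExitStack() as stack:
#       ip_adapter_model = stack.enter_context(
#           context.models.load(model_name=name, model_type=ModelType.IPAdapter, base_model=base)
#       )
#       # use the model here; it is released when the with-block exits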
@@ -505,7 +482,7 @@ def prep_ip_adapter_data( if not isinstance(single_ipa_images, list): single_ipa_images = [single_ipa_images] - single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images] + single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. @@ -532,7 +509,7 @@ def prep_ip_adapter_data( def run_t2i_adapters( self, - context: InvocationContext, + context: "InvocationContext", t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], latents_shape: list[int], do_classifier_free_guidance: bool, @@ -549,13 +526,12 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.get_model( + t2i_adapter_model_info = context.models.load( model_name=t2i_adapter_field.t2i_adapter_model.model_name, model_type=ModelType.T2IAdapter, base_model=t2i_adapter_field.t2i_adapter_model.base_model, - context=context, ) - image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name) + image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1: @@ -642,30 +618,30 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context, latents): + def prep_inpaint_mask(self, context: "InvocationContext", latents): if self.denoise_mask is None: return None, None - mask = context.services.latents.get(self.denoise_mask.mask_name) + mask = context.latents.get(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name) + masked_latents = context.latents.get(self.denoise_mask.masked_latents_name) else: masked_latents = None return 1 - mask, masked_latents @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: + def invoke(self, context) -> LatentsOutput: with SilenceWarnings(): # this quenches NSFW nag from diffusers seed = None noise = None if self.noise is not None: - noise = context.services.latents.get(self.noise.latents_name) + noise = context.latents.get(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.latents.get(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -691,27 +667,17 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: do_classifier_free_guidance=True, ) - # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) + context.util.sd_step_callback(state, self.unet.unet.base_model) def _lora_loader(): for 
lora in self.unet.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), - context=context, - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - unet_info = context.services.model_manager.get_model( - **self.unet.unet.model_dump(), - context=context, - ) + unet_info = context.models.load(**self.unet.unet.model_dump()) with ( ExitStack() as exit_stack, ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), @@ -787,9 +753,8 @@ def _lora_loader(): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, result_latents) - return build_latents_output(latents_name=name, latents=result_latents, seed=seed) + name = context.latents.save(tensor=result_latents) + return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed) @invocation( @@ -797,7 +762,7 @@ def _lora_loader(): title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", - version="1.2.0", + version="1.2.1", ) class LatentsToImageInvocation(BaseInvocation, WithMetadata): """Generates an image from latents.""" @@ -814,13 +779,10 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> ImageOutput: + latents = context.latents.get(self.latents.latents_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: latents = latents.to(vae.device) @@ -849,7 +811,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: vae.to(dtype=torch.float16) latents = latents.half() - if self.tiled or context.services.configuration.tiled_decode: + if self.tiled or context.config.get().tiled_decode: vae.enable_tiling() else: vae.disable_tiling() @@ -873,22 +835,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] @@ -899,7 +848,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Resize Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ResizeLatentsInvocation(BaseInvocation): """Resizes latents to explicit width/height (in pixels). 
Provided dimensions are floor-divided by 8.""" @@ -921,8 +870,8 @@ class ResizeLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -940,10 +889,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.latents.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -951,7 +898,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: title="Scale Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ScaleLatentsInvocation(BaseInvocation): """Scales latents by a given factor.""" @@ -964,8 +911,8 @@ class ScaleLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -984,10 +931,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.latents.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -995,7 +940,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", - version="1.0.0", + version="1.0.1", ) class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" @@ -1055,13 +1000,10 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): return latents @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> LatentsOutput: + image = context.images.get_pil(self.image.image_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: @@ -1069,10 +1011,9 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: latents 
= self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) - name = f"{context.graph_execution_state_id}__{self.id}" latents = latents.to("cpu") - context.services.latents.save(name, latents) - return build_latents_output(latents_name=name, latents=latents, seed=None) + name = context.latents.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) @singledispatchmethod @staticmethod @@ -1092,7 +1033,7 @@ def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTenso title="Blend Latents", tags=["latents", "blend"], category="latents", - version="1.0.0", + version="1.0.1", ) class BlendLatentsInvocation(BaseInvocation): """Blend two latents using a given alpha. Latents must have same size.""" @@ -1107,9 +1048,9 @@ class BlendLatentsInvocation(BaseInvocation): ) alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.services.latents.get(self.latents_a.latents_name) - latents_b = context.services.latents.get(self.latents_b.latents_name) + def invoke(self, context) -> LatentsOutput: + latents_a = context.latents.get(self.latents_a.latents_name) + latents_b = context.latents.get(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1163,10 +1104,8 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, blended_latents) - return build_latents_output(latents_name=name, latents=blended_latents) + name = context.latents.save(tensor=blended_latents) + return LatentsOutput.build(latents_name=name, latents=blended_latents) # The Crop Latents node was copied from @skunkworxdark's implementation here: @@ -1176,7 +1115,7 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): title="Crop Latents", tags=["latents", "crop"], category="latents", - version="1.0.0", + version="1.0.1", ) # TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. # Currently, if the class names conflict then 'GET /openapi.json' fails. @@ -1210,8 +1149,8 @@ class CropLatentsCoreInvocation(BaseInvocation): description="The height (in px) of the crop rectangle in image space. 
This value will be converted to a dimension in latent space.", ) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR @@ -1220,10 +1159,9 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: cropped_latents = latents[..., y1:y2, x1:x2] - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, cropped_latents) + name = context.latents.save(tensor=cropped_latents) - return build_latents_output(latents_name=name, latents=cropped_latents) + return LatentsOutput.build(latents_name=name, latents=cropped_latents) @invocation_output("ideal_size_output") diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 6ca53011f0b..d2dbf049816 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0") @@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a + self.b) @@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a - self.b) @@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a * self.b) @@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=int(self.a / self.b)) @@ -69,7 +69,7 @@ class RandomIntInvocation(BaseInvocation): low: int = InputField(default=0, description=FieldDescriptions.inclusive_low) high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=np.random.randint(self.low, self.high)) @@ -88,7 +88,7 @@ class RandomFloatInvocation(BaseInvocation): high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high) decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places) - def invoke(self, context: InvocationContext) -> 
FloatOutput: + def invoke(self, context) -> FloatOutput: random_float = np.random.uniform(self.low, self.high) rounded_float = round(random_float, self.decimals) return FloatOutput(value=rounded_float) @@ -110,7 +110,7 @@ class FloatToIntegerInvocation(BaseInvocation): default="Nearest", description="The method to use for rounding" ) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: if self.method == "Nearest": return IntegerOutput(value=round(self.value / self.multiple) * self.multiple) elif self.method == "Floor": @@ -128,7 +128,7 @@ class RoundInvocation(BaseInvocation): value: float = InputField(default=0, description="The float value") decimals: int = InputField(default=0, description="The number of decimal places") - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: return FloatOutput(value=round(self.value, self.decimals)) @@ -196,7 +196,7 @@ def no_unrepresentable_results(cls, v: int, info: ValidationInfo): raise ValueError("Result of exponentiation is not an integer") return v - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return IntegerOutput(value=self.a + self.b) @@ -270,7 +270,7 @@ def no_unrepresentable_results(cls, v: float, info: ValidationInfo): raise ValueError("Root operation resulted in a complex number") return v - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return FloatOutput(value=self.a + self.b) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 399e217dc17..9d74abd8c12 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -5,15 +5,20 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import ControlField -from invokeai.app.invocations.fields import FieldDescriptions, InputField, MetadataField, OutputField, UIType +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + InputField, + MetadataField, + OutputField, + UIType, +) from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField -from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.t2i_adapter import T2IAdapterField from ...version import __version__ @@ -59,7 +64,7 @@ class MetadataItemInvocation(BaseInvocation): label: str = InputField(description=FieldDescriptions.metadata_item_label) value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any) - def invoke(self, context: InvocationContext) -> MetadataItemOutput: + def invoke(self, context) -> MetadataItemOutput: return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value)) @@ -76,7 +81,7 @@ class MetadataInvocation(BaseInvocation): description=FieldDescriptions.metadata_item_polymorphic ) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: if isinstance(self.items, 
MetadataItemField): # single metadata item data = {self.items.label: self.items.value} @@ -95,7 +100,7 @@ class MergeMetadataInvocation(BaseInvocation): collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: data = {} for item in self.collection: data.update(item.model_dump()) @@ -213,7 +218,7 @@ class CoreMetadataInvocation(BaseInvocation): description="The start value used for refiner denoising", ) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: """Collects and outputs a CoreMetadata object""" return MetadataOutput( diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c710c9761b0..f81e559e446 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -10,7 +10,6 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -102,7 +101,7 @@ class LoRAModelField(BaseModel): title="Main Model", tags=["model"], category="model", - version="1.0.0", + version="1.0.1", ) class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -110,13 +109,13 @@ class MainModelLoaderInvocation(BaseInvocation): model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? - def invoke(self, context: InvocationContext) -> ModelLoaderOutput: + def invoke(self, context) -> ModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -203,7 +202,7 @@ class LoraLoaderOutput(BaseInvocationOutput): clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0") +@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1") class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -222,14 +221,14 @@ class LoraLoaderInvocation(BaseInvocation): title="CLIP", ) - def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + def invoke(self, context) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") base_model = self.lora.base_model lora_name = self.lora.model_name - if not context.services.model_manager.model_exists( + if not context.models.exists( base_model=base_model, model_name=lora_name, model_type=ModelType.Lora, @@ -285,7 +284,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput): title="SDXL LoRA", tags=["lora", "model"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -311,14 +310,14 @@ class SDXLLoraLoaderInvocation(BaseInvocation): title="CLIP 2", ) - def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: + def invoke(self, context) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") base_model = self.lora.base_model lora_name = self.lora.model_name - if not context.services.model_manager.model_exists( + if not context.models.exists( 
base_model=base_model, model_name=lora_name, model_type=ModelType.Lora, @@ -384,7 +383,7 @@ class VAEModelField(BaseModel): model_config = ConfigDict(protected_namespaces=()) -@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") +@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" @@ -394,12 +393,12 @@ class VaeLoaderInvocation(BaseInvocation): title="VAE", ) - def invoke(self, context: InvocationContext) -> VAEOutput: + def invoke(self, context) -> VAEOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae - if not context.services.model_manager.model_exists( + if not context.models.exists( base_model=base_model, model_name=model_name, model_type=model_type, @@ -449,7 +448,7 @@ class SeamlessModeInvocation(BaseInvocation): seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") - def invoke(self, context: InvocationContext) -> SeamlessModeOutput: + def invoke(self, context) -> SeamlessModeOutput: # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) vae = copy.deepcopy(self.vae) @@ -485,6 +484,6 @@ class FreeUInvocation(BaseInvocation): s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1) s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2) - def invoke(self, context: InvocationContext) -> UNetOutput: + def invoke(self, context) -> UNetOutput: self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2) return UNetOutput(unet=self.unet) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 2e717ac561b..41641152f04 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,15 +4,13 @@ import torch from pydantic import field_validator -from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField -from invokeai.app.invocations.latent import LatentsField +from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -67,13 +65,13 @@ class NoiseOutput(BaseInvocationOutput): width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) - -def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): - return NoiseOutput( - noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": + return cls( + noise=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) @invocation( @@ -114,7 +112,7 @@ def modulo_seed(cls, v): """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" return v % (SEED_MAX + 1) - def invoke(self, context: InvocationContext) -> 
NoiseOutput: + def invoke(self, context) -> NoiseOutput: noise = get_noise( width=self.width, height=self.height, @@ -122,6 +120,5 @@ def invoke(self, context: InvocationContext) -> NoiseOutput: seed=self.seed, use_cpu=self.use_cpu, ) - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, noise) - return build_noise_output(latents_name=name, latents=noise, seed=self.seed) + name = context.latents.save(tensor=noise) + return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index b43d7eaef2c..3f8e6669ab8 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -37,7 +37,7 @@ invocation_output, ) from .controlnet_image_processors import ControlField -from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler +from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, get_scheduler from .model import ClipField, ModelInfo, UNetField, VaeField ORT_TO_NP_TYPE = { @@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation): prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.model_dump(), ) @@ -201,7 +201,7 @@ def ge_one(cls, v): # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context: InvocationContext) -> LatentsOutput: + def invoke(self, context) -> LatentsOutput: c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) @@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): ) # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) if self.vae.vae.submodel != SubModelType.VaeDecoder: @@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation): description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel ) - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: + def invoke(self, context) -> ONNXModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.ONNX diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index dab9c3dc0f4..bf59e87d270 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -41,7 +41,7 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -62,7 +62,7 @@ class FloatLinearRangeInvocation(BaseInvocation): 
description="number of values to interpolate over (including start and stop)", ) - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) return FloatCollectionOutput(collection=param_list) @@ -110,7 +110,7 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: title="Step Param Easing", tags=["step", "easing"], category="step", - version="1.0.0", + version="1.0.1", ) class StepParamEasingInvocation(BaseInvocation): """Experimental per-step parameter easing for denoising steps""" @@ -130,7 +130,7 @@ class StepParamEasingInvocation(BaseInvocation): # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing") show_easing_plot: bool = InputField(default=False, description="show easing plot") - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: log_diagnostics = False # convert from start_step_percent to nearest step <= (steps * start_step_percent) # start_step = int(np.floor(self.num_steps * self.start_step_percent)) @@ -149,19 +149,19 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: postlist = list(num_poststeps * [self.post_end_value]) if log_diagnostics: - context.services.logger.debug("start_step: " + str(start_step)) - context.services.logger.debug("end_step: " + str(end_step)) - context.services.logger.debug("num_easing_steps: " + str(num_easing_steps)) - context.services.logger.debug("num_presteps: " + str(num_presteps)) - context.services.logger.debug("num_poststeps: " + str(num_poststeps)) - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("postlist size: " + str(len(postlist))) - context.services.logger.debug("prelist: " + str(prelist)) - context.services.logger.debug("postlist: " + str(postlist)) + context.logger.debug("start_step: " + str(start_step)) + context.logger.debug("end_step: " + str(end_step)) + context.logger.debug("num_easing_steps: " + str(num_easing_steps)) + context.logger.debug("num_presteps: " + str(num_presteps)) + context.logger.debug("num_poststeps: " + str(num_poststeps)) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist: " + str(prelist)) + context.logger.debug("postlist: " + str(postlist)) easing_class = EASING_FUNCTIONS_MAP[self.easing] if log_diagnostics: - context.services.logger.debug("easing class: " + str(easing_class)) + context.logger.debug("easing class: " + str(easing_class)) easing_list = [] if self.mirror: # "expected" mirroring # if number of steps is even, squeeze duration down to (number_of_steps)/2 @@ -172,7 +172,7 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: base_easing_duration = int(np.ceil(num_easing_steps / 2.0)) if log_diagnostics: - context.services.logger.debug("base easing duration: " + str(base_easing_duration)) + context.logger.debug("base easing duration: " + str(base_easing_duration)) even_num_steps = num_easing_steps % 2 == 0 # even number of steps easing_function = easing_class( start=self.start_value, @@ -184,14 +184,14 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: easing_val = easing_function.ease(step_index) base_easing_vals.append(easing_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + 
", easing_val: " + str(easing_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) if even_num_steps: mirror_easing_vals = list(reversed(base_easing_vals)) else: mirror_easing_vals = list(reversed(base_easing_vals[0:-1])) if log_diagnostics: - context.services.logger.debug("base easing vals: " + str(base_easing_vals)) - context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) + context.logger.debug("base easing vals: " + str(base_easing_vals)) + context.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) easing_list = base_easing_vals + mirror_easing_vals # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely @@ -226,12 +226,12 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: step_val = easing_function.ease(step_index) easing_list.append(step_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) if log_diagnostics: - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("easing_list size: " + str(len(easing_list))) - context.services.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("easing_list size: " + str(len(easing_list))) + context.logger.debug("postlist size: " + str(len(postlist))) param_list = prelist + easing_list + postlist diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 22f03454a55..ee04345eed8 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -1,16 +1,26 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -from typing import Optional, Tuple +from typing import Optional import torch -from pydantic import BaseModel, Field -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent +from invokeai.app.invocations.fields import ( + ColorField, + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + LatentsField, + OutputField, + UIComponent, +) +from invokeai.app.services.images.images_common import ImageDTO from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -49,7 +59,7 @@ class BooleanInvocation(BaseInvocation): value: bool = InputField(default=False, description="The boolean value") - def invoke(self, context: InvocationContext) -> BooleanOutput: + def invoke(self, context) -> BooleanOutput: return BooleanOutput(value=self.value) @@ -65,7 +75,7 @@ class BooleanCollectionInvocation(BaseInvocation): collection: list[bool] = InputField(default=[], description="The collection of boolean values") - def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: + def invoke(self, context) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -98,7 +108,7 @@ class IntegerInvocation(BaseInvocation): value: int = InputField(default=0, description="The integer value") - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.value) @@ -114,7 +124,7 @@ class IntegerCollectionInvocation(BaseInvocation): collection: list[int] = InputField(default=[], 
description="The collection of integer values") - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -145,7 +155,7 @@ class FloatInvocation(BaseInvocation): value: float = InputField(default=0.0, description="The float value") - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: return FloatOutput(value=self.value) @@ -161,7 +171,7 @@ class FloatCollectionInvocation(BaseInvocation): collection: list[float] = InputField(default=[], description="The collection of float values") - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -192,7 +202,7 @@ class StringInvocation(BaseInvocation): value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=self.value) @@ -208,7 +218,7 @@ class StringCollectionInvocation(BaseInvocation): collection: list[str] = InputField(default=[], description="The collection of string values") - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -217,18 +227,6 @@ def invoke(self, context: InvocationContext) -> StringCollectionOutput: # region Image -class ImageField(BaseModel): - """An image primitive field""" - - image_name: str = Field(description="The name of the image") - - -class BoardField(BaseModel): - """A board primitive field""" - - board_id: str = Field(description="The id of the board") - - @invocation_output("image_output") class ImageOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" @@ -237,6 +235,14 @@ class ImageOutput(BaseInvocationOutput): width: int = OutputField(description="The width of the image in pixels") height: int = OutputField(description="The height of the image in pixels") + @classmethod + def build(cls, image_dto: ImageDTO) -> "ImageOutput": + return cls( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) + @invocation_output("image_collection_output") class ImageCollectionOutput(BaseInvocationOutput): @@ -247,7 +253,7 @@ class ImageCollectionOutput(BaseInvocationOutput): ) -@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0") +@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1") class ImageInvocation( BaseInvocation, ): @@ -255,8 +261,8 @@ class ImageInvocation( image: ImageField = InputField(description="The image to load") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) return ImageOutput( image=ImageField(image_name=self.image.image_name), @@ -277,7 +283,7 @@ class ImageCollectionInvocation(BaseInvocation): collection: list[ImageField] = InputField(description="The collection of image values") - def invoke(self, context: InvocationContext) -> ImageCollectionOutput: + def invoke(self, context) -> 
ImageCollectionOutput: return ImageCollectionOutput(collection=self.collection) @@ -286,32 +292,24 @@ def invoke(self, context: InvocationContext) -> ImageCollectionOutput: # region DenoiseMask -class DenoiseMaskField(BaseModel): - """An inpaint mask field""" - - mask_name: str = Field(description="The name of the mask image") - masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") - - @invocation_output("denoise_mask_output") class DenoiseMaskOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") + @classmethod + def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput": + return cls( + denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name), + ) + # endregion # region Latents -class LatentsField(BaseModel): - """A latents tensor primitive field""" - - latents_name: str = Field(description="The name of the latents") - seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") - - @invocation_output("latents_output") class LatentsOutput(BaseInvocationOutput): """Base class for nodes that output a single latents tensor""" @@ -322,6 +320,14 @@ class LatentsOutput(BaseInvocationOutput): width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": + return cls( + latents=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + @invocation_output("latents_collection_output") class LatentsCollectionOutput(BaseInvocationOutput): @@ -333,17 +339,17 @@ class LatentsCollectionOutput(BaseInvocationOutput): @invocation( - "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0" + "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1" ) class LatentsInvocation(BaseInvocation): """A latents tensor primitive value""" latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) - return build_latents_output(self.latents.latents_name, latents) + return LatentsOutput.build(self.latents.latents_name, latents) @invocation( @@ -360,35 +366,15 @@ class LatentsCollectionInvocation(BaseInvocation): description="The collection of latents tensors", ) - def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: + def invoke(self, context) -> LatentsCollectionOutput: return LatentsCollectionOutput(collection=self.collection) -def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) - - # endregion # region Color -class ColorField(BaseModel): - """A color primitive field""" - - r: int = Field(ge=0, le=255, description="The red component") - g: int = Field(ge=0, le=255, description="The green component") - b: int = Field(ge=0, 
le=255, description="The blue component") - a: int = Field(ge=0, le=255, description="The alpha component") - - def tuple(self) -> Tuple[int, int, int, int]: - return (self.r, self.g, self.b, self.a) - - @invocation_output("color_output") class ColorOutput(BaseInvocationOutput): """Base class for nodes that output a single color""" @@ -411,7 +397,7 @@ class ColorInvocation(BaseInvocation): color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value") - def invoke(self, context: InvocationContext) -> ColorOutput: + def invoke(self, context) -> ColorOutput: return ColorOutput(color=self.color) @@ -420,18 +406,16 @@ def invoke(self, context: InvocationContext) -> ColorOutput: # region Conditioning -class ConditioningField(BaseModel): - """A conditioning tensor primitive value""" - - conditioning_name: str = Field(description="The name of conditioning tensor") - - @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) + @classmethod + def build(cls, conditioning_name: str) -> "ConditioningOutput": + return cls(conditioning=ConditioningField(conditioning_name=conditioning_name)) + @invocation_output("conditioning_collection_output") class ConditioningCollectionOutput(BaseInvocationOutput): @@ -454,7 +438,7 @@ class ConditioningInvocation(BaseInvocation): conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection) - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: return ConditioningOutput(conditioning=self.conditioning) @@ -473,7 +457,7 @@ class ConditioningCollectionInvocation(BaseInvocation): description="The collection of conditioning tensors", ) - def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: + def invoke(self, context) -> ConditioningCollectionOutput: return ConditioningCollectionOutput(collection=self.collection) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 94b4a217ae7..4f5ef43a568 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, UIComponent @@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation): max_prompts: int = InputField(default=1, description="The number of prompts to generate") combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: if self.combinatorial: generator = CombinatorialPromptGenerator() prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) @@ -91,7 +91,7 @@ def promptsFromFile( break return prompts - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: prompts = self.promptsFromFile( self.file_path, self.pre_prompt, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 62df5bc8047..75a526cfff6 100644 --- a/invokeai/app/invocations/sdxl.py +++ 
b/invokeai/app/invocations/sdxl.py @@ -4,7 +4,6 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -30,7 +29,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0") +@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" @@ -39,13 +38,13 @@ class SDXLModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: + def invoke(self, context) -> SDXLModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -116,7 +115,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: title="SDXL Refiner Model", tags=["model", "sdxl", "refiner"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" @@ -128,13 +127,13 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: + def invoke(self, context) -> SDXLRefinerModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index ccbc2f6d924..a4c92d9de56 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -5,7 +5,6 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -33,7 +32,7 @@ class StringSplitNegInvocation(BaseInvocation): string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringPosNegOutput: + def invoke(self, context) -> StringPosNegOutput: p_string = "" n_string = "" brackets_depth = 0 @@ -77,7 +76,7 @@ class StringSplitInvocation(BaseInvocation): default="", description="Delimiter to spilt with. 
blank will split on the first whitespace" ) - def invoke(self, context: InvocationContext) -> String2Output: + def invoke(self, context) -> String2Output: result = self.string.split(self.delimiter, 1) if len(result) == 2: part1, part2 = result @@ -95,7 +94,7 @@ class StringJoinInvocation(BaseInvocation): string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) @@ -107,7 +106,7 @@ class StringJoinThreeInvocation(BaseInvocation): string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or ""))) @@ -126,7 +125,7 @@ class StringReplaceInvocation(BaseInvocation): default=False, description="Use search string as a regex expression (non regex is case insensitive)" ) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: pattern = self.search_string or "" new_string = self.string or "" if len(pattern) > 0: diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 66ac87c37b8..74a098a501c 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -5,13 +5,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField -from invokeai.app.invocations.primitives import ImageField +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.backend.model_management.models.base import BaseModelType @@ -91,7 +89,7 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> T2IAdapterOutput: + def invoke(self, context) -> T2IAdapterOutput: return T2IAdapterOutput( t2i_adapter=T2IAdapterField( image=self.image, diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index bdc23ef6edd..dd34c3dc093 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,13 +8,12 @@ BaseInvocation, BaseInvocationOutput, Classification, - InvocationContext, + WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import Input, InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField +from invokeai.app.invocations.primitives import 
ImageOutput from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, calc_tiles_min_overlap, @@ -58,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation): description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount", ) - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_with_overlap( image_height=self.image_height, image_width=self.image_width, @@ -101,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation): description="The overlap, in pixels, between adjacent tiles.", ) - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_even_split( image_height=self.image_height, image_width=self.image_width, @@ -131,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.") - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_min_overlap( image_height=self.image_height, image_width=self.image_width, @@ -176,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation): tile: Tile = InputField(description="The tile to split into properties.") - def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: + def invoke(self, context) -> TileToPropertiesOutput: return TileToPropertiesOutput( coords_left=self.tile.coords.left, coords_right=self.tile.coords.right, @@ -213,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation): image: ImageField = InputField(description="The tile image.") tile: Tile = InputField(description="The tile properties.") - def invoke(self, context: InvocationContext) -> PairTileImageOutput: + def invoke(self, context) -> PairTileImageOutput: return PairTileImageOutput( tile_with_image=TileWithImage( tile=self.tile, @@ -249,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.", ) - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: images = [twi.image for twi in self.tiles_with_images] tiles = [twi.tile for twi in self.tiles_with_images] @@ -265,7 +264,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # existed in memory at an earlier point in the graph. 
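# Each tile image is re-fetched by name and converted to an RGB numpy array before merging.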
tile_np_images: list[np.ndarray] = [] for image in images: - pil_image = context.services.images.get_pil_image(image.image_name) + pil_image = context.images.get_pil(image.image_name) pil_image = pil_image.convert("RGB") tile_np_images.append(np.array(pil_image)) @@ -288,18 +287,5 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert into a PIL image and save pil_image = Image.fromarray(np_image) - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image_dto = context.images.save(image=pil_image) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 2cab279a9fc..ef174809860 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -8,13 +8,13 @@ from PIL import Image from pydantic import ConfigDict -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata # TODO: Populate this from disk? 
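For orientation between these hunks: tiles.py and upscale.py converge on the same node-facing image I/O surface. A minimal sketch of the resulting pattern, assuming the import locations this series settles on later (PATCH 12 moves `WithMetadata` into `fields.py`); the node name and passthrough behavior are illustrative, not part of these patches:

```py
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput


@invocation("example_passthrough", title="Example Passthrough", version="1.0.0")
class ExamplePassthroughInvocation(BaseInvocation, WithMetadata):
    """Hypothetical node showing the new context API: load, save, build output."""

    image: ImageField = InputField(description="The input image")

    def invoke(self, context) -> ImageOutput:
        # Load the input image as a PIL image by name.
        pil_image = context.images.get_pil(self.image.image_name)
        # Save it; metadata and workflow are attached automatically.
        image_dto = context.images.save(image=pil_image)
        # Build the standard ImageOutput from the returned DTO.
        return ImageOutput.build(image_dto)
```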
@@ -30,7 +30,7 @@ from torch import mps -@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0") +@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1") class ESRGANInvocation(BaseInvocation, WithMetadata): """Upscales an image using RealESRGAN.""" @@ -42,9 +42,9 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): model_config = ConfigDict(protected_namespaces=()) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - models_path = context.services.configuration.models_path + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + models_path = context.config.get().models_path rrdbnet_model = None netscale = None @@ -88,7 +88,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: netscale = 2 else: msg = f"Invalid RealESRGAN model: {self.model_name}" - context.services.logger.error(msg) + context.logger.error(msg) raise ValueError(msg) esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}") @@ -111,19 +111,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index e9365f33495..ad08ae03956 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -55,7 +55,7 @@ def emit_generator_progress( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - node: dict, + node_id: str, source_node_id: str, progress_image: Optional[ProgressImage], step: int, @@ -70,7 +70,7 @@ def emit_generator_progress( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "node_id": node.get("id"), + "node_id": node_id, "source_node_id": source_node_id, "progress_image": progress_image.model_dump() if progress_image is not None else None, "step": step, diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py index 54342c0da13..d2ebe235e63 100644 --- a/invokeai/app/services/invocation_processor/invocation_processor_default.py +++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py @@ -5,11 +5,11 @@ from typing import Optional import invokeai.backend.util.logging as logger -from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem from invokeai.app.services.invocation_stats.invocation_stats_common import ( GESStatsNotFoundError, ) +from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import 
Profiler from ..invoker import Invoker @@ -131,16 +131,20 @@ def stats_cleanup(graph_execution_state_id: str) -> None: # which handles a few things: # - nodes that require a value, but get it only from a connection # - referencing the invocation cache instead of executing the node - outputs = invocation.invoke_internal( - InvocationContext( - services=self.__invoker.services, - graph_execution_state_id=graph_execution_state.id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - queue_batch_id=queue_item.session_queue_batch_id, - workflow=queue_item.workflow, - ) + context_data = InvocationContextData( + invocation=invocation, + session_id=graph_id, + workflow=queue_item.workflow, + source_node_id=source_node_id, + queue_id=queue_item.session_queue_id, + queue_item_id=queue_item.session_queue_item_id, + batch_id=queue_item.session_queue_batch_id, + ) + context = build_invocation_context( + services=self.__invoker.services, + context_data=context_data, ) + outputs = invocation.invoke_internal(context=context, services=self.__invoker.services) # Check queue to see if this is canceled, and skip if so if self.__invoker.services.queue.is_canceled(graph_execution_state.id): diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 4c2fc4c085c..a9b53ae2242 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -5,11 +5,12 @@ from abc import ABC, abstractmethod from logging import Logger from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union from pydantic import Field from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_management import ( AddModelResult, BaseModelType, @@ -21,9 +22,6 @@ ) from invokeai.backend.model_management.model_cache import CacheStats -if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext - class ModelManagerServiceBase(ABC): """Responsible for managing models on disk and in memory""" @@ -49,8 +47,7 @@ def get_model( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> ModelInfo: """Retrieve the indicated model with name and type. 
submodel can be used to get a part (such as the vae) diff --git a/invokeai/app/services/model_manager/model_manager_common.py b/invokeai/app/services/model_manager/model_manager_common.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index cdb3e59a91c..b641dd3f1ed 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -11,6 +11,8 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_management import ( AddModelResult, BaseModelType, @@ -30,7 +32,7 @@ from .model_manager_base import ModelManagerServiceBase if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import InvocationContext + pass # simple implementation @@ -86,13 +88,16 @@ def __init__( ) logger.info("Model manager service initialized") + def start(self, invoker: Invoker) -> None: + self._invoker: Optional[Invoker] = invoker + def get_model( self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> ModelInfo: """ Retrieve the indicated model. submodel can be used to get a @@ -100,9 +105,9 @@ def get_model( """ # we can emit model loading events if we are executing with access to the invocation context - if context: + if context_data is not None: self._emit_load_event( - context=context, + context_data=context_data, model_name=model_name, base_model=base_model, model_type=model_type, @@ -116,9 +121,9 @@ def get_model( submodel, ) - if context: + if context_data is not None: self._emit_load_event( - context=context, + context_data=context_data, model_name=model_name, base_model=base_model, model_type=model_type, @@ -263,22 +268,25 @@ def commit(self, conf_file: Optional[Path] = None): def _emit_load_event( self, - context: InvocationContext, + context_data: InvocationContextData, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, model_info: Optional[ModelInfo] = None, ): - if context.services.queue.is_canceled(context.graph_execution_state_id): + if self._invoker is None: + return + + if self._invoker.services.queue.is_canceled(context_data.session_id): raise CanceledException() if model_info: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, + self._invoker.services.events.emit_model_load_completed( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_name=model_name, base_model=base_model, model_type=model_type, @@ -286,11 +294,11 @@ def _emit_load_event( model_info=model_info, ) else: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, 
+ self._invoker.services.events.emit_model_load_started( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index ba05b050c5b..c0699eb96bb 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -13,7 +13,6 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -202,7 +201,7 @@ class GraphInvocation(BaseInvocation): # TODO: figure out how to create a default here graph: "Graph" = InputField(description="The graph to run", default=None) - def invoke(self, context: InvocationContext) -> GraphInvocationOutput: + def invoke(self, context) -> GraphInvocationOutput: """Invoke with provided services and return outputs.""" return GraphInvocationOutput() @@ -228,7 +227,7 @@ class IterateInvocation(BaseInvocation): ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) - def invoke(self, context: InvocationContext) -> IterateInvocationOutput: + def invoke(self, context) -> IterateInvocationOutput: """Produces the outputs as values""" return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection)) @@ -255,7 +254,7 @@ class CollectInvocation(BaseInvocation): description="The collection, will be provided on execution", default=[], ui_hidden=True ) - def invoke(self, context: InvocationContext) -> CollectInvocationOutput: + def invoke(self, context) -> CollectInvocationOutput: """Invoke with provided services and return outputs.""" return CollectInvocationOutput(collection=copy.copy(self.collection)) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index c0aaac54f87..b68e521c73f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -6,8 +6,7 @@ from pydantic import ConfigDict from torch import Tensor -from invokeai.app.invocations.compel import ConditioningFieldData -from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -245,13 +244,15 @@ def save(conditioning_data: ConditioningFieldData) -> str: ) return name - def get(conditioning_name: str) -> Tensor: + def get(conditioning_name: str) -> ConditioningFieldData: """ Gets conditioning data by name. :param conditioning_name: The name of the conditioning data to get. """ - return services.latents.get(conditioning_name) + # TODO(sm): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed as returning tensors, so we need to ignore the type here. 
+ return services.latents.get(conditioning_name) # type: ignore [return-value] self.save = save self.get = get diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 5cc3caa9ba5..d83b380d95d 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,25 +1,18 @@ -from typing import Protocol +from typing import TYPE_CHECKING import torch from PIL import Image -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage -from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC -from invokeai.app.services.shared.invocation_context import InvocationContextData from ...backend.model_management.models import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL - -class StepCallback(Protocol): - def __call__( - self, - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - ... +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC + from invokeai.app.services.shared.invocation_context import InvocationContextData def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): @@ -38,11 +31,11 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix= def stable_diffusion_step_callback( - context_data: InvocationContextData, + context_data: "InvocationContextData", intermediate_state: PipelineIntermediateState, base_model: BaseModelType, - invocation_queue: InvocationQueueABC, - events: EventServiceBase, + invocation_queue: "InvocationQueueABC", + events: "EventServiceBase", ) -> None: if invocation_queue.is_canceled(context_data.session_id): raise CanceledException From 4aa7bee4b9c4235e4118092be69aa5bf88237722 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:27 +1100 Subject: [PATCH 05/67] docs: update INVOCATIONS.md --- docs/contributing/INVOCATIONS.md | 97 ++++++++++++-------------- 1 file changed, 36 insertions(+), 61 deletions(-) diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md index 124589f44ce..5d9a3690bad 100644 --- a/docs/contributing/INVOCATIONS.md +++ b/docs/contributing/INVOCATIONS.md @@ -9,11 +9,15 @@ complex functionality. ## Invocations Directory -InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These can be used as examples to create your own nodes. +InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These +can be used as examples to create your own nodes. -New nodes should be added to a subfolder in `nodes` direction found at the root level of the InvokeAI installation location. Nodes added to this folder will be able to be used upon application startup. +New nodes should be added to a subfolder in the `nodes` directory found at the +root level of the InvokeAI installation location. Nodes added to this folder +will be available for use upon application startup.
+ +Example `nodes` subfolder structure: -Example `nodes` subfolder structure: ```py ├── __init__.py # Invoke-managed custom node loader │ @@ -30,14 +34,14 @@ Example `nodes` subfolder structure: └── fancy_node.py ``` -Each node folder must have an `__init__.py` file that imports its nodes. Only nodes imported in the `__init__.py` file are loaded. - See the README in the nodes folder for more examples: +Each node folder must have an `__init__.py` file that imports its nodes. Only +nodes imported in the `__init__.py` file are loaded. See the README in the nodes +folder for more examples: ```py from .cool_node import CoolInvocation ``` - ## Creating A New Invocation In order to understand the process of creating a new Invocation, let us actually @@ -131,7 +135,6 @@ from invokeai.app.invocations.primitives import ImageField class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") @@ -167,12 +170,11 @@ from invokeai.app.invocations.primitives import ImageField class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context: InvocationContext): + def invoke(self, context): pass ``` @@ -197,12 +199,11 @@ from invokeai.app.invocations.image import ImageOutput class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: pass ``` @@ -228,31 +229,18 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context: InvocationContext) -> ImageOutput: - # Load the image using InvokeAI's predefined Image Service. Returns the PIL image. - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + # Load the input image as a PIL image + image = context.images.get_pil(self.image.image_name) - # Resizing the image + # Resize the image resized_image = image.resize((self.width, self.height)) - # Save the image using InvokeAI's predefined Image Service. Returns the prepared PIL image. 
- output_image = context.services.images.create( - image=resized_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) - - # Returning the Image - return ImageOutput( - image=ImageField( - image_name=output_image.image_name, - ), - width=output_image.width, - height=output_image.height, - ) + # Save the image + image_dto = context.images.save(image=resized_image) + + # Return an ImageOutput + return ImageOutput.build(image_dto) ``` **Note:** Do not be overwhelmed by the `ImageOutput` process. InvokeAI has a @@ -343,27 +331,25 @@ class ImageColorStringOutput(BaseInvocationOutput): That's all there is to it. - +Custom fields only support connection inputs in the Workflow Editor. From ac2eb16a658e4e119a6d17fecaaa899ea13eff0f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:38 +1100 Subject: [PATCH 06/67] tests: fix tests for new invocation context --- tests/aa_nodes/test_graph_execution_state.py | 8 +------- tests/aa_nodes/test_nodes.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index fab1fa4598f..9cc30e43e11 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -21,7 +21,6 @@ from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService -from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.shared.graph import ( CollectInvocation, Graph, @@ -86,12 +85,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B print(f"invoking {n.id}: {type(n)}") o = n.invoke( InvocationContext( - queue_batch_id="1", - queue_item_id=1, - queue_id=DEFAULT_QUEUE_ID, - services=services, - graph_execution_state_id="1", - workflow=None, + conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None ) ) g.complete(n.id, o) diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index e71daad3f3a..559457c0e11 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -3,7 +3,6 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -21,7 +20,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput): class ListPassThroughInvocation(BaseInvocation): collection: list[ImageField] = InputField(default=[]) - def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: + def invoke(self, context) -> ListPassThroughInvocationOutput: return ListPassThroughInvocationOutput(collection=self.collection) @@ -34,13 +33,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput): class PromptTestInvocation(BaseInvocation): prompt: str = InputField(default="") - def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + def invoke(self, context) -> PromptTestInvocationOutput: return PromptTestInvocationOutput(prompt=self.prompt) @invocation("test_error", version="1.0.0") class ErrorInvocation(BaseInvocation): - def 
invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + def invoke(self, context) -> PromptTestInvocationOutput: raise Exception("This invocation is supposed to fail") @@ -54,7 +53,7 @@ class TextToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") prompt2: str = InputField(default="") - def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + def invoke(self, context) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -63,7 +62,7 @@ class ImageToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") image: Union[ImageField, None] = InputField(default=None) - def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + def invoke(self, context) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -76,7 +75,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocation(BaseInvocation): collection: list[str] = InputField() - def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + def invoke(self, context) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) @@ -89,7 +88,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput): class AnyTypeTestInvocation(BaseInvocation): value: Any = InputField(default=None) - def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: + def invoke(self, context) -> AnyTypeTestInvocationOutput: return AnyTypeTestInvocationOutput(value=self.value) @@ -97,7 +96,7 @@ def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: class PolymorphicStringTestInvocation(BaseInvocation): value: Union[str, list[str]] = InputField(default="") - def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + def invoke(self, context) -> PromptCollectionTestInvocationOutput: if isinstance(self.value, str): return PromptCollectionTestInvocationOutput(collection=[self.value]) return PromptCollectionTestInvocationOutput(collection=self.value) From 8baf3f78a29f5baf94753bfd11b5ddd19e7b0f63 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 14 Jan 2024 00:05:15 +1100 Subject: [PATCH 07/67] feat(nodes): tidy `invocation_context.py`, improve comments --- .../app/services/shared/invocation_context.py | 115 ++++++++++++------ 1 file changed, 80 insertions(+), 35 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index b68e521c73f..7961c011aff 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Optional from PIL.Image import Image @@ -37,6 +36,9 @@ When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere with each other. +Many of the wrappers have the same signature as the methods they wrap. This allows us to write +user-facing docstrings and not need to go and update the internal services to match. + Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them. 
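The wrapper shape this docstring describes is easiest to see in miniature. A hedged sketch (the service and method names here are invented for illustration; the real interfaces in this file follow the same closures-defined-in-`__init__` shape):

```py
class ExampleInterface:
    """Hypothetical wrapper exposing one documented method over a raw service."""

    def __init__(self, services: "InvocationServices", context_data: "InvocationContextData") -> None:
        def get_thing(name: str) -> str:
            """The user-facing docstring lives here, where IDEs can surface it."""
            # Same signature as the wrapped service method, so docs stay in one place.
            return services.example_service.get(name)

        self.get_thing = get_thing
```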
""" @@ -44,12 +46,19 @@ @dataclass(frozen=True) class InvocationContextData: invocation: "BaseInvocation" + """The invocation that is being executed.""" session_id: str + """The session that is being executed.""" queue_id: str + """The queue in which the session is being executed.""" source_node_id: str + """The ID of the node from which the currently executing invocation was prepared.""" queue_item_id: int + """The ID of the queue item that is being executed.""" batch_id: str + """The ID of the batch that is being executed.""" workflow: Optional[WorkflowWithoutID] = None + """The workflow associated with this queue item, if any.""" class LoggerInterface: @@ -103,14 +112,15 @@ def save( """ Saves an image, returning its DTO. - If the current queue item has a workflow, it is automatically saved with the image. + If the current queue item has a workflow or metadata, it is automatically saved with the image. :param image: The image to save, as a PIL image. :param board_id: The board ID to add the image to, if it should be added. - :param image_category: The category of the image. Only the GENERAL category is added to the gallery. - :param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \ - from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \ - override or provide metadata manually. + :param image_category: The category of the image. Only the GENERAL category is added \ + to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the \ + invocation inherits from `WithMetadata`, that metadata will be used automatically. \ + **Use this only if you want to override or provide metadata manually!** """ # If the invocation inherits metadata, use that. Else, use the metadata passed in. @@ -186,14 +196,6 @@ def update( self.update = update -class LatentsKind(str, Enum): - IMAGE = "image" - NOISE = "noise" - MASK = "mask" - MASKED_IMAGE = "masked_image" - OTHER = "other" - - class LatentsInterface: def __init__( self, @@ -206,6 +208,22 @@ def save(tensor: Tensor) -> str: :param tensor: The latents tensor to save. """ + + # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. + # "mask", "noise", "masked_latents", etc. + # + # Retaining that capability in this wrapper would require either many different methods + # to save latents, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all latents. + # + # This has a very minor impact as we don't use them after a session completes. + + # Previously, invocations chose the name for their latents. This is a bit risky, so we + # will generate a name for them instead. We use a uuid to ensure the name is unique. + # + # Because the name of the latents file will includes the session and invocation IDs, + # we don't need to worry about collisions. A truncated UUIDv4 is fine. + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" services.latents.save( name=name, @@ -231,12 +249,21 @@ def __init__( services: InvocationServices, context_data: InvocationContextData, ) -> None: + # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed to work with Tensors only. We have to fudge the types here. + def save(conditioning_data: ConditioningFieldData) -> str: """ Saves a conditioning data object, returning its name. 
:param conditioning_data: The conditioning data to save. """ + + # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. + # + # See comment for `LatentsInterface.save` for more info about this method (it's very + # similar). + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" services.latents.save( name=name, @@ -250,9 +277,8 @@ def get(conditioning_name: str) -> ConditioningFieldData: :param conditioning_name: The name of the conditioning data to get. """ - # TODO(sm): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed as returning tensors, so we need to ignore the type here. - return services.latents.get(conditioning_name) # type: ignore [return-value] + + return services.latents.get(conditioning_name) # type: ignore [return-value] self.save = save self.get = get @@ -281,6 +307,17 @@ def load( :param model_type: The type of the model to get. :param submodel: The submodel of the model to get. """ + + # During this call, the model manager emits events with model loading status. The model + # manager itself has access to the events services, but does not have access to the + # required metadata for the events. + # + # For example, it needs access to the node's ID so that the events can be associated + # with the execution of a specific node. + # + # While this is available within the node, it's tedious to need to pass it in on every + # call. We can avoid that by wrapping the method here. + return services.model_manager.get_model( model_name, base_model, model_type, submodel, context_data=context_data ) @@ -306,8 +343,11 @@ def get() -> InvokeAIAppConfig: """ Gets the app's config. """ - # The config can be changed at runtime. We don't want nodes doing this, so we make a - # frozen copy.. + + # The config can be changed at runtime. + # + # We don't want nodes doing this, so we make a frozen copy. + config = services.configuration.get_config() frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) return frozen_config @@ -330,6 +370,12 @@ def sd_step_callback( :param intermediate_state: The intermediate state of the diffusion pipeline. :param base_model: The base model for the current denoising step. """ + + # The step callback needs access to the events and the invocation queue services, but this + # represents a dangerous level of access. + # + # We wrap the step callback so that nodes do not have direct access to these services. + stable_diffusion_step_callback( context_data=context_data, intermediate_state=intermediate_state, @@ -343,36 +389,36 @@ def sd_step_callback( class InvocationContext: """ - The invocation context provides access to various services and data about the current invocation. + The `InvocationContext` provides access to various services and data for the current invocation. """ def __init__( self, images: ImagesInterface, latents: LatentsInterface, + conditioning: ConditioningInterface, models: ModelsInterface, - config: ConfigInterface, logger: LoggerInterface, - data: InvocationContextData, + config: ConfigInterface, util: UtilInterface, - conditioning: ConditioningInterface, + data: InvocationContextData, ) -> None: self.images = images - "Provides methods to save, get and update images and their metadata." - self.logger = logger - "Provides access to the app logger." 
+ """Provides methods to save, get and update images and their metadata.""" self.latents = latents - "Provides methods to save and get latents tensors, including image, noise, masks, and masked images." + """Provides methods to save and get latents tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning - "Provides methods to save and get conditioning data." + """Provides methods to save and get conditioning data.""" self.models = models - "Provides methods to check if a model exists, get a model, and get a model's info." + """Provides methods to check if a model exists, get a model, and get a model's info.""" + self.logger = logger + """Provides access to the app logger.""" self.config = config - "Provides access to the app's config." - self.data = data - "Provides data about the current queue item and invocation." + """Provides access to the app's config.""" self.util = util - "Provides utility methods." + """Provides utility methods.""" + self.data = data + """Provides data about the current queue item and invocation.""" def build_invocation_context( @@ -380,8 +426,7 @@ def build_invocation_context( context_data: InvocationContextData, ) -> InvocationContext: """ - Builds the invocation context. This is a wrapper around the invocation services that provides - a more convenient (and less dangerous) interface for nodes to use. + Builds the invocation context for a specific invocation execution. :param invocation_services: The invocation services to wrap. :param invocation_context_data: The invocation context data. From 183c9c4799baa0f4c530b7190752602828ff3cb4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 14 Jan 2024 00:34:56 +1100 Subject: [PATCH 08/67] chore: ruff --- invokeai/app/invocations/baseinvocation.py | 1 - invokeai/app/invocations/onnx.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index c4aed1fac5a..df0596c9a15 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -22,7 +22,6 @@ Input, InputFieldJSONSchemaExtra, MetadataField, - logger, ) from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.shared.invocation_context import InvocationContext diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 3f8e6669ab8..a1e318a3802 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -318,7 +318,7 @@ def dispatch_progress( name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, latents) - return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) + # return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) # Latent to image From 5c7ed24aab1bff551aa8dde8ff9b8671b4e1e3f5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 14 Jan 2024 20:16:51 +1100 Subject: [PATCH 09/67] feat(nodes): restore previous invocation context methods with deprecation warnings --- .../app/services/shared/invocation_context.py | 117 +++++++++++++++++- pyproject.toml | 1 + 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 7961c011aff..023274d49fa 100644 --- 
a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional +from deprecated import deprecated from PIL.Image import Image from pydantic import ConfigDict from torch import Tensor @@ -365,7 +366,7 @@ def sd_step_callback( The step callback emits a progress event with the current step, the total number of steps, a preview image, and some other internal metadata. - This should be called after each step of the diffusion process. + This should be called after each denoising step. :param intermediate_state: The intermediate state of the diffusion pipeline. :param base_model: The base model for the current denoising step. @@ -387,6 +388,30 @@ def sd_step_callback( self.sd_step_callback = sd_step_callback +deprecation_version = "3.7.0" +removed_version = "3.8.0" + + +def get_deprecation_reason(property_name: str, alternative: Optional[str] = None) -> str: + msg = f"{property_name} is deprecated as of v{deprecation_version}. It will be removed in v{removed_version}." + if alternative is not None: + msg += f" Use {alternative} instead." + msg += " See PLACEHOLDER_URL for details." + return msg + + +# Deprecation docstrings template. I don't think we can implement these programmatically with +# __doc__ because the IDE won't see them. + +""" +**DEPRECATED as of v3.7.0** + +PROPERTY_NAME will be removed in v3.8.0. Use ALTERNATIVE instead. See PLACEHOLDER_URL for details. + +OG_DOCSTRING +""" + + class InvocationContext: """ The `InvocationContext` provides access to various services and data for the current invocation. """ def __init__( self, images: ImagesInterface, latents: LatentsInterface, conditioning: ConditioningInterface, models: ModelsInterface, logger: LoggerInterface, config: ConfigInterface, util: UtilInterface, data: InvocationContextData, + services: InvocationServices, ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" self.latents = latents """Provides methods to save and get latents tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning """Provides methods to save and get conditioning data.""" self.models = models """Provides methods to check if a model exists, get a model, and get a model's info.""" self.logger = logger """Provides access to the app logger.""" self.config = config """Provides access to the app's config.""" self.util = util """Provides utility methods.""" self.data = data """Provides data about the current queue item and invocation.""" + self.__services = services + + @property + @deprecated(version=deprecation_version, reason=get_deprecation_reason("`context.services`")) + def services(self) -> InvocationServices: + """ + **DEPRECATED as of v3.7.0** + + `context.services` will be removed in v3.8.0. See PLACEHOLDER_URL for details. + + The invocation services. + """ + return self.__services + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.graph_execution_state_id`", "`context.data.session_id`"), + ) + def graph_execution_state_id(self) -> str: + """ + **DEPRECATED as of v3.7.0** + + `context.graph_execution_state_id` will be removed in v3.8.0. Use `context.data.session_id` instead. See PLACEHOLDER_URL for details. + + The ID of the session (aka graph execution state). + """ + return self.data.session_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.queue_id`", "`context.data.queue_id`"), + ) + def queue_id(self) -> str: + """ + **DEPRECATED as of v3.7.0** + + `context.queue_id` will be removed in v3.8.0. Use `context.data.queue_id` instead. See PLACEHOLDER_URL for details. + + The ID of the queue.
+ """ + return self.data.queue_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.queue_item_id`", "`context.data.queue_item_id`"), + ) + def queue_item_id(self) -> int: + """ + **DEPRECATED as of v3.7.0** + + `context.queue_item_id` will be removed in v3.8.0. Use `context.data.queue_item_id` instead. See PLACEHOLDER_URL for details. + + The ID of the queue item. + """ + return self.data.queue_item_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.queue_batch_id`", "`context.data.batch_id`"), + ) + def queue_batch_id(self) -> str: + """ + **DEPRECATED as of v3.7.0** + + `context.queue_batch_id` will be removed in v3.8.0. Use `context.data.batch_id` instead. See PLACEHOLDER_URL for details. + + The ID of the batch. + """ + return self.data.batch_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.workflow`", "`context.data.workflow`"), + ) + def workflow(self) -> Optional[WorkflowWithoutID]: + """ + **DEPRECATED as of v3.7.0** + + `context.workflow` will be removed in v3.8.0. Use `context.data.workflow` instead. See PLACEHOLDER_URL for details. + + The workflow associated with this queue item, if any. + """ + return self.data.workflow def build_invocation_context( @@ -449,6 +563,7 @@ def build_invocation_context( data=context_data, util=util, conditioning=conditioning, + services=services, ) return ctx diff --git a/pyproject.toml b/pyproject.toml index d063f1ad0ee..8d25ed20910 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dependencies = [ "albumentations", "click", "datasets", + "Deprecated", "dnspython~=2.4.0", "dynamicprompts", "easing-functions", From 3ceee2b2b248bba93a49799e2dce4ac643dcf199 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 09:37:05 +1100 Subject: [PATCH 10/67] tests: fix missing arg for InvocationContext --- tests/aa_nodes/test_graph_execution_state.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 9cc30e43e11..3577a78ae2c 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -85,7 +85,15 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B print(f"invoking {n.id}: {type(n)}") o = n.invoke( InvocationContext( - conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None + conditioning=None, + config=None, + data=None, + images=None, + latents=None, + logger=None, + models=None, + util=None, + services=None, ) ) g.complete(n.id, o) From 3de43907113c316f48fd96bcd9afd773675a4a00 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:41:25 +1100 Subject: [PATCH 11/67] feat(nodes): move `ConditioningFieldData` to `conditioning_data.py` --- invokeai/app/invocations/compel.py | 2 +- invokeai/app/invocations/fields.py | 9 +-------- invokeai/app/services/shared/invocation_context.py | 3 ++- .../stable_diffusion/diffusion/conditioning_data.py | 5 +++++ 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index b4496031bc4..94caf4128d2 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ 
-5,7 +5,6 @@ from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from invokeai.app.invocations.fields import ( - ConditioningFieldData, FieldDescriptions, Input, InputField, @@ -15,6 +14,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, + ConditioningFieldData, ExtraConditioningInfo, SDXLConditioningInfo, ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 566babbb6b7..8879f760770 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -1,13 +1,11 @@ -from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Optional, Tuple from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter from pydantic.fields import _Unset from pydantic_core import PydanticUndefined from invokeai.app.util.metaenum import MetaEnum -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -544,11 +542,6 @@ def tuple(self) -> Tuple[int, int, int, int]: return (self.r, self.g, self.b, self.a) -@dataclass -class ConditioningFieldData: - conditionings: List[BasicConditioningInfo] - - class ConditioningField(BaseModel): """A conditioning tensor primitive value""" diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 023274d49fa..3cf3952de0e 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -6,7 +6,7 @@ from pydantic import ConfigDict from torch import Tensor -from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata +from invokeai.app.invocations.fields import MetadataField, WithMetadata from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -17,6 +17,7 @@ from invokeai.backend.model_management.model_manager import ModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 3e38f9f78d5..0676555f7a9 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -32,6 +32,11 @@ def to(self, device, dtype=None): return self +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] + + @dataclass class SDXLConditioningInfo(BasicConditioningInfo): pooled_embeds: torch.Tensor From a7e23af9c607fd2abb43891e60d5d9bfbc4d5713 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:48:33 +1100 Subject: [PATCH 12/67] 
feat(nodes): create invocation_api.py This is the public API for invocations. Everything a custom node might need should be re-exported from this file. --- .../controlnet_image_processors.py | 3 +- invokeai/app/invocations/facetools.py | 3 +- invokeai/app/invocations/image.py | 2 +- invokeai/app/invocations/infill.py | 4 +- invokeai/app/invocations/tiles.py | 3 +- invokeai/invocation_api/__init__.py | 109 ++++++++++++++++++ pyproject.toml | 1 + 7 files changed, 116 insertions(+), 9 deletions(-) create mode 100644 invokeai/invocation_api/__init__.py diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 3797722c93e..e993ceffde5 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,8 +25,7 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.baseinvocation import WithMetadata -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.backend.image_util.depth_anything import DepthAnythingDetector diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 2c92e28cfe0..dad63089816 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -13,11 +13,10 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import ImageField, InputField, OutputField +from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 10ebd97ace3..3b8b0b4b80b 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,7 +7,6 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from invokeai.app.invocations.baseinvocation import WithMetadata from invokeai.app.invocations.fields import ( BoardField, ColorField, @@ -15,6 +14,7 @@ ImageField, Input, InputField, + WithMetadata, ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index be51c8312f9..159bdb5f7ad 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -13,8 +13,8 @@ from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, WithMetadata, invocation -from .fields import InputField +from .baseinvocation import BaseInvocation, invocation +from .fields import InputField, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index dd34c3dc093..0b4c472696b 100644 
--- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,11 +8,10 @@ BaseInvocation, BaseInvocationOutput, Classification, - WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField +from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py new file mode 100644 index 00000000000..e867ec3cc4e --- /dev/null +++ b/invokeai/invocation_api/__init__.py @@ -0,0 +1,109 @@ +""" +This file re-exports all the public API for invocations. This is the only file that should be imported by custom nodes. + +TODO(psyche): Do we want to dogfood this? +""" + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import ( + BoardField, + ColorField, + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + FieldKind, + ImageField, + Input, + InputField, + LatentsField, + MetadataField, + OutputField, + UIComponent, + UIType, + WithMetadata, + WithWorkflow, +) +from invokeai.app.invocations.primitives import ( + BooleanCollectionOutput, + BooleanOutput, + ColorCollectionOutput, + ColorOutput, + ConditioningCollectionOutput, + ConditioningOutput, + DenoiseMaskOutput, + FloatCollectionOutput, + FloatOutput, + ImageCollectionOutput, + ImageOutput, + IntegerCollectionOutput, + IntegerOutput, + LatentsCollectionOutput, + LatentsOutput, + StringCollectionOutput, + StringOutput, +) +from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + BasicConditioningInfo, + ConditioningFieldData, + ExtraConditioningInfo, + SDXLConditioningInfo, +) + +__all__ = [ + # invokeai.app.invocations.baseinvocation + "BaseInvocation", + "BaseInvocationOutput", + "invocation", + "invocation_output", + # invokeai.app.services.shared.invocation_context + "InvocationContext", + # invokeai.app.invocations.fields + "BoardField", + "ColorField", + "ConditioningField", + "DenoiseMaskField", + "FieldDescriptions", + "FieldKind", + "ImageField", + "Input", + "InputField", + "LatentsField", + "MetadataField", + "OutputField", + "UIComponent", + "UIType", + "WithMetadata", + "WithWorkflow", + # invokeai.app.invocations.primitives + "BooleanCollectionOutput", + "BooleanOutput", + "ColorCollectionOutput", + "ColorOutput", + "ConditioningCollectionOutput", + "ConditioningOutput", + "DenoiseMaskOutput", + "FloatCollectionOutput", + "FloatOutput", + "ImageCollectionOutput", + "ImageOutput", + "IntegerCollectionOutput", + "IntegerOutput", + "LatentsCollectionOutput", + "LatentsOutput", + "StringCollectionOutput", + "StringOutput", + # invokeai.app.services.image_records.image_records_common + "ImageCategory", + # invokeai.backend.stable_diffusion.diffusion.conditioning_data + "BasicConditioningInfo", + "ConditioningFieldData", + "ExtraConditioningInfo", + "SDXLConditioningInfo", +] diff --git a/pyproject.toml b/pyproject.toml index 8d25ed20910..69958064c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,7 @@ version = { attr = "invokeai.version.__version__" } 
"invokeai.frontend.web.static*", "invokeai.configs*", "invokeai.app*", + "invokeai.invocation_api*", ] [tool.setuptools.package-data] From cc295a9f0a5e36fec2ddee522e5b990555cf803f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 16 Jan 2024 19:19:49 +1100 Subject: [PATCH 13/67] feat: tweak pyright config --- pyproject.toml | 47 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 69958064c6d..8b28375e291 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -284,17 +284,36 @@ module = [ #=== End: MyPy [tool.pyright] -include = [ - "invokeai/app/invocations/" -] -exclude = [ - "**/node_modules", - "**/__pycache__", - "invokeai/app/invocations/onnx.py", - "invokeai/app/api/routers/models.py", - "invokeai/app/services/invocation_stats/invocation_stats_default.py", - "invokeai/app/services/model_manager/model_manager_base.py", - "invokeai/app/services/model_manager/model_manager_default.py", - "invokeai/app/services/model_records/model_records_sql.py", - "invokeai/app/util/controlnet_utils.py", -] +# Start from strict mode +typeCheckingMode = "strict" +# This errors whenever an import is missing a type stub file - way too noisy +reportMissingTypeStubs = "none" +# These are the rest of the rules enabled by strict mode - enable them @ warning +reportConstantRedefinition = "warning" +reportDeprecated = "warning" +reportDuplicateImport = "warning" +reportIncompleteStub = "warning" +reportInconsistentConstructor = "warning" +reportInvalidStubStatement = "warning" +reportMatchNotExhaustive = "warning" +reportMissingParameterType = "warning" +reportMissingTypeArgument = "warning" +reportPrivateUsage = "warning" +reportTypeCommentUsage = "warning" +reportUnknownArgumentType = "warning" +reportUnknownLambdaType = "warning" +reportUnknownMemberType = "warning" +reportUnknownParameterType = "warning" +reportUnknownVariableType = "warning" +reportUnnecessaryCast = "warning" +reportUnnecessaryComparison = "warning" +reportUnnecessaryContains = "warning" +reportUnnecessaryIsInstance = "warning" +reportUnusedClass = "warning" +reportUnusedImport = "warning" +reportUnusedFunction = "warning" +reportUnusedVariable = "warning" +reportUntypedBaseClass = "warning" +reportUntypedClassDecorator = "warning" +reportUntypedFunctionDecorator = "warning" +reportUntypedNamedTuple = "warning" From ae421fb4ab3527e40b261eab12f87750db8d08a7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:02:38 +1100 Subject: [PATCH 14/67] feat(nodes): do not freeze InvocationContextData, prevents it from being subclassesd --- invokeai/app/services/shared/invocation_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 3cf3952de0e..a849d6b17a2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -45,7 +45,7 @@ """ -@dataclass(frozen=True) +@dataclass class InvocationContextData: invocation: "BaseInvocation" """The invocation that is being executed.""" From 483bdbcb9fbc612ee0d855bbc3a6023e4535e988 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:16:35 +1100 Subject: [PATCH 15/67] fix(nodes): restore type annotations for `InvocationContext` --- 
docs/contributing/INVOCATIONS.md | 6 +-- invokeai/app/invocations/collections.py | 7 +-- invokeai/app/invocations/compel.py | 18 +++---- .../controlnet_image_processors.py | 5 +- invokeai/app/invocations/cv.py | 3 +- invokeai/app/invocations/facetools.py | 24 ++++----- invokeai/app/invocations/image.py | 51 ++++++++++--------- invokeai/app/invocations/infill.py | 11 ++-- invokeai/app/invocations/ip_adapter.py | 3 +- invokeai/app/invocations/latent.py | 18 +++---- invokeai/app/invocations/math.py | 21 ++++---- invokeai/app/invocations/metadata.py | 9 ++-- invokeai/app/invocations/model.py | 13 ++--- invokeai/app/invocations/noise.py | 3 +- invokeai/app/invocations/onnx.py | 8 +-- invokeai/app/invocations/param_easing.py | 5 +- invokeai/app/invocations/primitives.py | 31 +++++------ invokeai/app/invocations/prompt.py | 5 +- invokeai/app/invocations/sdxl.py | 5 +- invokeai/app/invocations/strings.py | 12 +++-- invokeai/app/invocations/t2i_adapter.py | 3 +- invokeai/app/invocations/tiles.py | 13 ++--- invokeai/app/invocations/upscale.py | 3 +- invokeai/app/services/shared/graph.py | 7 +-- tests/aa_nodes/test_nodes.py | 17 ++++--- 25 files changed, 158 insertions(+), 143 deletions(-) diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md index 5d9a3690bad..ce1ee9e808a 100644 --- a/docs/contributing/INVOCATIONS.md +++ b/docs/contributing/INVOCATIONS.md @@ -174,7 +174,7 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context): + def invoke(self, context: InvocationContext): pass ``` @@ -203,7 +203,7 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pass ``` @@ -229,7 +229,7 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: # Load the input image as a PIL image image = context.images.get_pil(self.image.image_name) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index f5709b4ba36..e02291980f9 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -5,6 +5,7 @@ from pydantic import ValidationInfo, field_validator from invokeai.app.invocations.primitives import IntegerCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from .baseinvocation import BaseInvocation, invocation @@ -27,7 +28,7 @@ def stop_gt_start(cls, v: int, info: ValidationInfo): raise ValueError("stop must be greater than start") return v - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step))) @@ -45,7 +46,7 @@ class RangeOfSizeInvocation(BaseInvocation): size: int = InputField(default=1, gt=0, description="The number of 
values") step: int = InputField(default=1, description="The step of the range") - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput( collection=list(range(self.start, self.start + (self.step * self.size), self.step)) ) @@ -72,6 +73,6 @@ class RandomRangeInvocation(BaseInvocation): description="The seed for the RNG (omit for random)", ) - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: rng = np.random.default_rng(self.seed) return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 94caf4128d2..978c6dcb17f 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union import torch from compel import Compel, ReturnedEmbeddingsType @@ -12,6 +12,7 @@ UIComponent, ) from invokeai.app.invocations.primitives import ConditioningOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, @@ -31,10 +32,7 @@ ) from .model import ClipField -if TYPE_CHECKING: - from invokeai.app.services.shared.invocation_context import InvocationContext - - # unconditioned: Optional[torch.Tensor] +# unconditioned: Optional[torch.Tensor] # class ConditioningAlgo(str, Enum): @@ -65,7 +63,7 @@ class CompelInvocation(BaseInvocation): ) @torch.no_grad() - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) @@ -148,7 +146,7 @@ def _lora_loader(): class SDXLPromptInvocationBase: def run_clip_compel( self, - context: "InvocationContext", + context: InvocationContext, clip_field: ClipField, prompt: str, get_pooled: bool, @@ -288,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") @torch.no_grad() - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True ) @@ -373,7 +371,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: # TODO: if there will appear lora for refiner - write proper prefix c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) @@ -418,7 +416,7 @@ class ClipSkipInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) - def invoke(self, context) -> ClipSkipInvocationOutput: + def invoke(self, context: 
InvocationContext) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers return ClipSkipInvocationOutput( clip=self.clip, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index e993ceffde5..f8bdf14117c 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -28,6 +28,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector from invokeai.backend.model_management.models.base import BaseModelType @@ -119,7 +120,7 @@ def validate_begin_end_step_percent(self) -> "ControlNetInvocation": validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context) -> ControlOutput: + def invoke(self, context: InvocationContext) -> ControlOutput: return ControlOutput( control=ControlField( image=self.image, @@ -143,7 +144,7 @@ def run_processor(self, image: Image.Image) -> Image.Image: # superclass just passes through image without processing return image - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: raw_image = context.images.get_pil(self.image.image_name) # image type should be PIL.PngImagePlugin.PngImageFile ? processed_image = self.run_processor(raw_image) diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 375b18f9c58..1ebabf5e064 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -7,6 +7,7 @@ from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata @@ -19,7 +20,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to inpaint") mask: ImageField = InputField(description="The mask to use when inpainting") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) mask = context.images.get_pil(self.mask.image_name) diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index dad63089816..a1702d6517c 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -1,7 +1,7 @@ import math import re from pathlib import Path -from typing import TYPE_CHECKING, Optional, TypedDict +from typing import Optional, TypedDict import cv2 import numpy as np @@ -19,9 +19,7 @@ from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory - -if TYPE_CHECKING: - from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.services.shared.invocation_context import InvocationContext 
@invocation_output("face_mask_output") @@ -176,7 +174,7 @@ def prepare_faces_list( def generate_face_box_mask( - context: "InvocationContext", + context: InvocationContext, minimum_confidence: float, x_offset: float, y_offset: float, @@ -275,7 +273,7 @@ def generate_face_box_mask( def extract_face( - context: "InvocationContext", + context: InvocationContext, image: ImageType, face: FaceResultData, padding: int, @@ -356,7 +354,7 @@ def extract_face( def get_faces_list( - context: "InvocationContext", + context: InvocationContext, image: ImageType, should_chunk: bool, minimum_confidence: float, @@ -458,7 +456,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.", ) - def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]: + def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]: all_faces = get_faces_list( context=context, image=image, @@ -485,7 +483,7 @@ def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[Ex return face_data - def invoke(self, context) -> FaceOffOutput: + def invoke(self, context: InvocationContext) -> FaceOffOutput: image = context.images.get_pil(self.image.image_name) result = self.faceoff(context=context, image=image) @@ -543,7 +541,7 @@ def validate_comma_separated_ints(cls, v) -> str: raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")') return v - def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult: + def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult: all_faces = get_faces_list( context=context, image=image, @@ -600,7 +598,7 @@ def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskRe mask=mask_pil, ) - def invoke(self, context) -> FaceMaskOutput: + def invoke(self, context: InvocationContext) -> FaceMaskOutput: image = context.images.get_pil(self.image.image_name) result = self.facemask(context=context, image=image) @@ -633,7 +631,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType: + def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType: image = image.copy() all_faces = get_faces_list( @@ -674,7 +672,7 @@ def faceidentifier(self, context: "InvocationContext", image: ImageType) -> Imag return image - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) result_image = self.faceidentifier(context=context, image=image) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 3b8b0b4b80b..7b74e4d96d4 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -18,6 +18,7 @@ ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker @@ -34,7 +35,7 @@ class ShowImageInvocation(BaseInvocation): image: ImageField = InputField(description="The image to show") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image.show() @@ -62,7 +63,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata): mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image") color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple()) image_dto = context.images.save(image=image) @@ -86,7 +87,7 @@ class ImageCropInvocation(BaseInvocation, WithMetadata): width: int = InputField(default=512, gt=0, description="The width of the crop rectangle") height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)) @@ -125,7 +126,7 @@ class CenterPadCropInvocation(BaseInvocation): description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) # Calculate and create new image dimensions @@ -161,7 +162,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): y: int = InputField(default=0, description="The top y coordinate at which to paste the image") crop: bool = InputField(default=False, description="Crop to base image dimensions") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: base_image = context.images.get_pil(self.base_image.image_name) image = context.images.get_pil(self.image.image_name) mask = None @@ -201,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): image: ImageField = 
InputField(description="The image to create the mask from") invert: bool = InputField(default=False, description="Whether or not to invert the mask") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_mask = image.split()[-1] @@ -226,7 +227,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata): image1: ImageField = InputField(description="The first image to multiply") image2: ImageField = InputField(description="The second image to multiply") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image1 = context.images.get_pil(self.image1.image_name) image2 = context.images.get_pil(self.image2.image_name) @@ -253,7 +254,7 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to get the channel from") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) channel_image = image.getchannel(self.channel) @@ -279,7 +280,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to convert") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) converted_image = image.convert(self.mode) @@ -304,7 +305,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata): # Metadata blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) blur = ( @@ -338,7 +339,7 @@ def pil_from_array(self, arr): def array_from_pil(self, img): return numpy.array(img) / 255 - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) mode = image.mode @@ -401,7 +402,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): height: int = InputField(default=512, gt=0, description="The height to resize to (px)") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -434,7 +435,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): ) resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -465,7 +466,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum output value") max: int = InputField(default=255, ge=0, le=255, description="The maximum output value") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: 
InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 @@ -492,7 +493,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) @@ -517,7 +518,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) logger = context.logger @@ -553,7 +554,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") text: str = InputField(default="InvokeAI", description="Watermark text") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) new_image = InvisibleWatermark.add_watermark(image, self.text) image_dto = context.images.save(image=new_image) @@ -579,7 +580,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): description="Second threshold for the hysteresis procedure in Canny edge detection" ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: mask = context.images.get_pil(self.image.image_name).convert("L") npimg = numpy.asarray(mask, dtype=numpy.uint8) @@ -613,7 +614,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata): mask1: ImageField = InputField(description="The first mask to combine") mask2: ImageField = InputField(description="The second image to combine") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: mask1 = context.images.get_pil(self.mask1.image_name).convert("L") mask2 = context.images.get_pil(self.mask2.image_name).convert("L") @@ -642,7 +643,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") mask_blur_radius: float = InputField(default=8, description="Mask blur radius") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pil_init_mask = None if self.mask is not None: pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L") @@ -741,7 +742,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to adjust") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.images.get_pil(self.image.image_name) # Convert image to HSV color space @@ -831,7 +832,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): channel: COLOR_CHANNELS = InputField(description="Which channel to adjust") offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by") - def invoke(self, context) -> ImageOutput: + 
def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple @@ -888,7 +889,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.") invert_channel: bool = InputField(default=False, description="Invert the channel after scaling") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple @@ -931,7 +932,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) @@ -953,7 +954,7 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image_dto = context.images.get_dto(self.image.image_name) image_dto = context.images.update( diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 159bdb5f7ad..b007edd9e42 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -8,6 +8,7 @@ from invokeai.app.invocations.fields import ColorField, ImageField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA @@ -129,7 +130,7 @@ class InfillColorInvocation(BaseInvocation, WithMetadata): description="The color to use to infill", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) @@ -155,7 +156,7 @@ class InfillTileInvocation(BaseInvocation, WithMetadata): description="The seed to use for tile generation (omit for random)", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size) @@ -176,7 +177,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name).convert("RGBA") resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -213,7 +214,7 @@ class 
LaMaInfillInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to infill") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) infilled = infill_lama(image.copy()) @@ -229,7 +230,7 @@ class CV2InfillInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to infill") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) infilled = infill_cv2(image.copy()) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index b836be04b58..845fcfa2848 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -13,6 +13,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_management.models.base import BaseModelType, ModelType from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id @@ -92,7 +93,7 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context) -> IPAdapterOutput: + def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. ip_adapter_info = context.models.get_info( self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 0127a6521e1..2cc84f80a73 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,7 +3,7 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import TYPE_CHECKING, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union import einops import numpy as np @@ -42,6 +42,7 @@ LatentsOutput, ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models import ModelType, SilenceWarnings @@ -70,9 +71,6 @@ from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField -if TYPE_CHECKING: - from invokeai.app.services.shared.invocation_context import InvocationContext - if choose_torch_device() == torch.device("mps"): from torch import mps @@ -177,7 +175,7 @@ def invoke(self, context) -> DenoiseMaskOutput: def get_scheduler( - context: "InvocationContext", + context: InvocationContext, scheduler_info: ModelInfo, scheduler_name: str, seed: int, @@ -300,7 +298,7 @@ def ge_one(cls, v): def get_conditioning_data( self, - context: "InvocationContext", + context: InvocationContext, scheduler, unet, seed, @@ -369,7 +367,7 @@ def __init__(self): def prep_control_data( self, - context: "InvocationContext", + context: InvocationContext, control_input: 
Union[ControlField, List[ControlField]], latents_shape: List[int], exit_stack: ExitStack, @@ -442,7 +440,7 @@ def prep_control_data( def prep_ip_adapter_data( self, - context: "InvocationContext", + context: InvocationContext, ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], conditioning_data: ConditioningData, exit_stack: ExitStack, @@ -509,7 +507,7 @@ def prep_ip_adapter_data( def run_t2i_adapters( self, - context: "InvocationContext", + context: InvocationContext, t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], latents_shape: list[int], do_classifier_free_guidance: bool, @@ -618,7 +616,7 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context: "InvocationContext", latents): + def prep_inpaint_mask(self, context: InvocationContext, latents): if self.denoise_mask is None: return None, None diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index d2dbf049816..83a092be69e 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -7,6 +7,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation @@ -18,7 +19,7 @@ class AddInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.a + self.b) @@ -29,7 +30,7 @@ class SubtractInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.a - self.b) @@ -40,7 +41,7 @@ class MultiplyInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.a * self.b) @@ -51,7 +52,7 @@ class DivideInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=int(self.a / self.b)) @@ -69,7 +70,7 @@ class RandomIntInvocation(BaseInvocation): low: int = InputField(default=0, description=FieldDescriptions.inclusive_low) high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=np.random.randint(self.low, self.high)) @@ -88,7 +89,7 @@ class RandomFloatInvocation(BaseInvocation): high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high) decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places) - def invoke(self, context) -> FloatOutput: + 
def invoke(self, context: InvocationContext) -> FloatOutput: random_float = np.random.uniform(self.low, self.high) rounded_float = round(random_float, self.decimals) return FloatOutput(value=rounded_float) @@ -110,7 +111,7 @@ class FloatToIntegerInvocation(BaseInvocation): default="Nearest", description="The method to use for rounding" ) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: if self.method == "Nearest": return IntegerOutput(value=round(self.value / self.multiple) * self.multiple) elif self.method == "Floor": @@ -128,7 +129,7 @@ class RoundInvocation(BaseInvocation): value: float = InputField(default=0, description="The float value") decimals: int = InputField(default=0, description="The number of decimal places") - def invoke(self, context) -> FloatOutput: + def invoke(self, context: InvocationContext) -> FloatOutput: return FloatOutput(value=round(self.value, self.decimals)) @@ -196,7 +197,7 @@ def no_unrepresentable_results(cls, v: int, info: ValidationInfo): raise ValueError("Result of exponentiation is not an integer") return v - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return IntegerOutput(value=self.a + self.b) @@ -270,7 +271,7 @@ def no_unrepresentable_results(cls, v: float, info: ValidationInfo): raise ValueError("Root operation resulted in a complex number") return v - def invoke(self, context) -> FloatOutput: + def invoke(self, context: InvocationContext) -> FloatOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return FloatOutput(value=self.a + self.b) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 9d74abd8c12..58edfab711a 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -20,6 +20,7 @@ from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField from invokeai.app.invocations.t2i_adapter import T2IAdapterField +from invokeai.app.services.shared.invocation_context import InvocationContext from ...version import __version__ @@ -64,7 +65,7 @@ class MetadataItemInvocation(BaseInvocation): label: str = InputField(description=FieldDescriptions.metadata_item_label) value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any) - def invoke(self, context) -> MetadataItemOutput: + def invoke(self, context: InvocationContext) -> MetadataItemOutput: return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value)) @@ -81,7 +82,7 @@ class MetadataInvocation(BaseInvocation): description=FieldDescriptions.metadata_item_polymorphic ) - def invoke(self, context) -> MetadataOutput: + def invoke(self, context: InvocationContext) -> MetadataOutput: if isinstance(self.items, MetadataItemField): # single metadata item data = {self.items.label: self.items.value} @@ -100,7 +101,7 @@ class MergeMetadataInvocation(BaseInvocation): collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection) - def invoke(self, context) -> MetadataOutput: + def invoke(self, context: InvocationContext) -> MetadataOutput: data = {} for item in self.collection: data.update(item.model_dump()) @@ -218,7 +219,7 @@ class 
CoreMetadataInvocation(BaseInvocation): description="The start value used for refiner denoising", ) - def invoke(self, context) -> MetadataOutput: + def invoke(self, context: InvocationContext) -> MetadataOutput: """Collects and outputs a CoreMetadata object""" return MetadataOutput( diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index f81e559e446..6a1fd6d36bc 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig from ...backend.model_management import BaseModelType, ModelType, SubModelType @@ -109,7 +110,7 @@ class MainModelLoaderInvocation(BaseInvocation): model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? - def invoke(self, context) -> ModelLoaderOutput: + def invoke(self, context: InvocationContext) -> ModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main @@ -221,7 +222,7 @@ class LoraLoaderInvocation(BaseInvocation): title="CLIP", ) - def invoke(self, context) -> LoraLoaderOutput: + def invoke(self, context: InvocationContext) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") @@ -310,7 +311,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): title="CLIP 2", ) - def invoke(self, context) -> SDXLLoraLoaderOutput: + def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") @@ -393,7 +394,7 @@ class VaeLoaderInvocation(BaseInvocation): title="VAE", ) - def invoke(self, context) -> VAEOutput: + def invoke(self, context: InvocationContext) -> VAEOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae @@ -448,7 +449,7 @@ class SeamlessModeInvocation(BaseInvocation): seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") - def invoke(self, context) -> SeamlessModeOutput: + def invoke(self, context: InvocationContext) -> SeamlessModeOutput: # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) vae = copy.deepcopy(self.vae) @@ -484,6 +485,6 @@ class FreeUInvocation(BaseInvocation): s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1) s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2) - def invoke(self, context) -> UNetOutput: + def invoke(self, context: InvocationContext) -> UNetOutput: self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2) return UNetOutput(unet=self.unet) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 41641152f04..78f13cc52d1 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -5,6 +5,7 @@ from pydantic import field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from 
...backend.util.devices import choose_torch_device, torch_dtype @@ -112,7 +113,7 @@ def modulo_seed(cls, v): """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" return v % (SEED_MAX + 1) - def invoke(self, context) -> NoiseOutput: + def invoke(self, context: InvocationContext) -> NoiseOutput: noise = get_noise( width=self.width, height=self.height, diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index a1e318a3802..e7b4d3d9fc5 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation): prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.model_dump(), ) @@ -201,7 +201,7 @@ def ge_one(cls, v): # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) @@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): ) # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) if self.vae.vae.submodel != SubModelType.VaeDecoder: @@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation): description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel ) - def invoke(self, context) -> ONNXModelLoaderOutput: + def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.ONNX diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index bf59e87d270..6845637de92 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -40,6 +40,7 @@ from matplotlib.ticker import MaxNLocator from invokeai.app.invocations.primitives import FloatCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -62,7 +63,7 @@ class FloatLinearRangeInvocation(BaseInvocation): description="number of values to interpolate over (including start and stop)", ) - def invoke(self, context) -> FloatCollectionOutput: + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) return FloatCollectionOutput(collection=param_list) @@ -130,7 +131,7 @@ class StepParamEasingInvocation(BaseInvocation): # alt_mirror: bool = InputField(default=False, description="alternative mirroring by 
dual easing") show_easing_plot: bool = InputField(default=False, description="show easing plot") - def invoke(self, context) -> FloatCollectionOutput: + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: log_diagnostics = False # convert from start_step_percent to nearest step <= (steps * start_step_percent) # start_step = int(np.floor(self.num_steps * self.start_step_percent)) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index ee04345eed8..c90d3230b2b 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -17,6 +17,7 @@ UIComponent, ) from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import ( BaseInvocation, @@ -59,7 +60,7 @@ class BooleanInvocation(BaseInvocation): value: bool = InputField(default=False, description="The boolean value") - def invoke(self, context) -> BooleanOutput: + def invoke(self, context: InvocationContext) -> BooleanOutput: return BooleanOutput(value=self.value) @@ -75,7 +76,7 @@ class BooleanCollectionInvocation(BaseInvocation): collection: list[bool] = InputField(default=[], description="The collection of boolean values") - def invoke(self, context) -> BooleanCollectionOutput: + def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -108,7 +109,7 @@ class IntegerInvocation(BaseInvocation): value: int = InputField(default=0, description="The integer value") - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.value) @@ -124,7 +125,7 @@ class IntegerCollectionInvocation(BaseInvocation): collection: list[int] = InputField(default=[], description="The collection of integer values") - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -155,7 +156,7 @@ class FloatInvocation(BaseInvocation): value: float = InputField(default=0.0, description="The float value") - def invoke(self, context) -> FloatOutput: + def invoke(self, context: InvocationContext) -> FloatOutput: return FloatOutput(value=self.value) @@ -171,7 +172,7 @@ class FloatCollectionInvocation(BaseInvocation): collection: list[float] = InputField(default=[], description="The collection of float values") - def invoke(self, context) -> FloatCollectionOutput: + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -202,7 +203,7 @@ class StringInvocation(BaseInvocation): value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(value=self.value) @@ -218,7 +219,7 @@ class StringCollectionInvocation(BaseInvocation): collection: list[str] = InputField(default=[], description="The collection of string values") - def invoke(self, context) -> StringCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -261,7 +262,7 @@ class ImageInvocation( image: ImageField = InputField(description="The image to load") - def invoke(self, context) -> ImageOutput: + def 
invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) return ImageOutput( @@ -283,7 +284,7 @@ class ImageCollectionInvocation(BaseInvocation): collection: list[ImageField] = InputField(description="The collection of image values") - def invoke(self, context) -> ImageCollectionOutput: + def invoke(self, context: InvocationContext) -> ImageCollectionOutput: return ImageCollectionOutput(collection=self.collection) @@ -346,7 +347,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) @@ -366,7 +367,7 @@ class LatentsCollectionInvocation(BaseInvocation): description="The collection of latents tensors", ) - def invoke(self, context) -> LatentsCollectionOutput: + def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: return LatentsCollectionOutput(collection=self.collection) @@ -397,7 +398,7 @@ class ColorInvocation(BaseInvocation): color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value") - def invoke(self, context) -> ColorOutput: + def invoke(self, context: InvocationContext) -> ColorOutput: return ColorOutput(color=self.color) @@ -438,7 +439,7 @@ class ConditioningInvocation(BaseInvocation): conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection) - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: return ConditioningOutput(conditioning=self.conditioning) @@ -457,7 +458,7 @@ class ConditioningCollectionInvocation(BaseInvocation): description="The collection of conditioning tensors", ) - def invoke(self, context) -> ConditioningCollectionOutput: + def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: return ConditioningCollectionOutput(collection=self.collection) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 4f5ef43a568..234743a0035 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -6,6 +6,7 @@ from pydantic import field_validator from invokeai.app.invocations.primitives import StringCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation from .fields import InputField, UIComponent @@ -29,7 +30,7 @@ class DynamicPromptInvocation(BaseInvocation): max_prompts: int = InputField(default=1, description="The number of prompts to generate") combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") - def invoke(self, context) -> StringCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: if self.combinatorial: generator = CombinatorialPromptGenerator() prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) @@ -91,7 +92,7 @@ def promptsFromFile( break return prompts - def invoke(self, context) -> StringCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: prompts = self.promptsFromFile( self.file_path, self.pre_prompt, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 
75a526cfff6..8d51674a046 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,4 +1,5 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType +from invokeai.app.services.shared.invocation_context import InvocationContext from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( @@ -38,7 +39,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context) -> SDXLModelLoaderOutput: + def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main @@ -127,7 +128,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context) -> SDXLRefinerModelLoaderOutput: + def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index a4c92d9de56..182c976cd77 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -2,6 +2,8 @@ import re +from invokeai.app.services.shared.invocation_context import InvocationContext + from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -32,7 +34,7 @@ class StringSplitNegInvocation(BaseInvocation): string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringPosNegOutput: + def invoke(self, context: InvocationContext) -> StringPosNegOutput: p_string = "" n_string = "" brackets_depth = 0 @@ -76,7 +78,7 @@ class StringSplitInvocation(BaseInvocation): default="", description="Delimiter to spilt with. 
blank will split on the first whitespace" ) - def invoke(self, context) -> String2Output: + def invoke(self, context: InvocationContext) -> String2Output: result = self.string.split(self.delimiter, 1) if len(result) == 2: part1, part2 = result @@ -94,7 +96,7 @@ class StringJoinInvocation(BaseInvocation): string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) @@ -106,7 +108,7 @@ class StringJoinThreeInvocation(BaseInvocation): string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or ""))) @@ -125,7 +127,7 @@ class StringReplaceInvocation(BaseInvocation): default=False, description="Use search string as a regex expression (non regex is case insensitive)" ) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: pattern = self.search_string or "" new_string = self.string or "" if len(pattern) > 0: diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 74a098a501c..0f4fe66ada1 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -11,6 +11,7 @@ from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_management.models.base import BaseModelType @@ -89,7 +90,7 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context) -> T2IAdapterOutput: + def invoke(self, context: InvocationContext) -> T2IAdapterOutput: return T2IAdapterOutput( t2i_adapter=T2IAdapterField( image=self.image, diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index 0b4c472696b..19ece423761 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -13,6 +13,7 @@ ) from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, calc_tiles_min_overlap, @@ -56,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation): description="The target overlap, in pixels, between adjacent tiles. 
Adjacent tiles will overlap by at least this amount", ) - def invoke(self, context) -> CalculateImageTilesOutput: + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: tiles = calc_tiles_with_overlap( image_height=self.image_height, image_width=self.image_width, @@ -99,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation): description="The overlap, in pixels, between adjacent tiles.", ) - def invoke(self, context) -> CalculateImageTilesOutput: + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: tiles = calc_tiles_even_split( image_height=self.image_height, image_width=self.image_width, @@ -129,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.") - def invoke(self, context) -> CalculateImageTilesOutput: + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: tiles = calc_tiles_min_overlap( image_height=self.image_height, image_width=self.image_width, @@ -174,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation): tile: Tile = InputField(description="The tile to split into properties.") - def invoke(self, context) -> TileToPropertiesOutput: + def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: return TileToPropertiesOutput( coords_left=self.tile.coords.left, coords_right=self.tile.coords.right, @@ -211,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation): image: ImageField = InputField(description="The tile image.") tile: Tile = InputField(description="The tile properties.") - def invoke(self, context) -> PairTileImageOutput: + def invoke(self, context: InvocationContext) -> PairTileImageOutput: return PairTileImageOutput( tile_with_image=TileWithImage( tile=self.tile, @@ -247,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): description="The amount to blend adjacent tiles in pixels. 
Must be <= the amount of overlap between adjacent tiles.", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: images = [twi.image for twi in self.tiles_with_images] tiles = [twi.tile for twi in self.tiles_with_images] diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index ef174809860..71ef7ca3aa0 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -10,6 +10,7 @@ from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device @@ -42,7 +43,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): model_config = ConfigDict(protected_namespaces=()) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) models_path = context.config.get().models_path diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index c0699eb96bb..3df230f5ee7 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -17,6 +17,7 @@ invocation_output, ) from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import uuid_string # in 3.10 this would be "from types import NoneType" @@ -201,7 +202,7 @@ class GraphInvocation(BaseInvocation): # TODO: figure out how to create a default here graph: "Graph" = InputField(description="The graph to run", default=None) - def invoke(self, context) -> GraphInvocationOutput: + def invoke(self, context: InvocationContext) -> GraphInvocationOutput: """Invoke with provided services and return outputs.""" return GraphInvocationOutput() @@ -227,7 +228,7 @@ class IterateInvocation(BaseInvocation): ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) - def invoke(self, context) -> IterateInvocationOutput: + def invoke(self, context: InvocationContext) -> IterateInvocationOutput: """Produces the outputs as values""" return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection)) @@ -254,7 +255,7 @@ class CollectInvocation(BaseInvocation): description="The collection, will be provided on execution", default=[], ui_hidden=True ) - def invoke(self, context) -> CollectInvocationOutput: + def invoke(self, context: InvocationContext) -> CollectInvocationOutput: """Invoke with provided services and return outputs.""" return CollectInvocationOutput(collection=copy.copy(self.collection)) diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index 559457c0e11..aab3d9c7b4b 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -8,6 +8,7 @@ ) from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.image import ImageField +from invokeai.app.services.shared.invocation_context import InvocationContext # Define test invocations before importing anything that uses invocations @@ -20,7 +21,7 @@ class 
ListPassThroughInvocationOutput(BaseInvocationOutput): class ListPassThroughInvocation(BaseInvocation): collection: list[ImageField] = InputField(default=[]) - def invoke(self, context) -> ListPassThroughInvocationOutput: + def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: return ListPassThroughInvocationOutput(collection=self.collection) @@ -33,13 +34,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput): class PromptTestInvocation(BaseInvocation): prompt: str = InputField(default="") - def invoke(self, context) -> PromptTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: return PromptTestInvocationOutput(prompt=self.prompt) @invocation("test_error", version="1.0.0") class ErrorInvocation(BaseInvocation): - def invoke(self, context) -> PromptTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: raise Exception("This invocation is supposed to fail") @@ -53,7 +54,7 @@ class TextToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") prompt2: str = InputField(default="") - def invoke(self, context) -> ImageTestInvocationOutput: + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -62,7 +63,7 @@ class ImageToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") image: Union[ImageField, None] = InputField(default=None) - def invoke(self, context) -> ImageTestInvocationOutput: + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -75,7 +76,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocation(BaseInvocation): collection: list[str] = InputField() - def invoke(self, context) -> PromptCollectionTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) @@ -88,7 +89,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput): class AnyTypeTestInvocation(BaseInvocation): value: Any = InputField(default=None) - def invoke(self, context) -> AnyTypeTestInvocationOutput: + def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: return AnyTypeTestInvocationOutput(value=self.value) @@ -96,7 +97,7 @@ def invoke(self, context) -> AnyTypeTestInvocationOutput: class PolymorphicStringTestInvocation(BaseInvocation): value: Union[str, list[str]] = InputField(default="") - def invoke(self, context) -> PromptCollectionTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: if isinstance(self.value, str): return PromptCollectionTestInvocationOutput(collection=[self.value]) return PromptCollectionTestInvocationOutput(collection=self.value) From e52434cb9985ee58d6efcbe4c0b8df5df3b4af55 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:40:49 +1100 Subject: [PATCH 16/67] feat(nodes): add boards interface to invocation context --- .../app/services/shared/invocation_context.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index a849d6b17a2..cbcaa6a5489 100644 --- 
a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -7,6 +7,7 @@ from torch import Tensor from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -63,6 +64,54 @@ class InvocationContextData: """The workflow associated with this queue item, if any.""" +class BoardsInterface: + def __init__(self, services: InvocationServices) -> None: + def create(board_name: str) -> BoardDTO: + """ + Creates a board. + + :param board_name: The name of the board to create. + """ + return services.boards.create(board_name) + + def get_dto(board_id: str) -> BoardDTO: + """ + Gets a board DTO. + + :param board_id: The ID of the board to get. + """ + return services.boards.get_dto(board_id) + + def get_all() -> list[BoardDTO]: + """ + Gets all boards. + """ + return services.boards.get_all() + + def add_image_to_board(board_id: str, image_name: str) -> None: + """ + Adds an image to a board. + + :param board_id: The ID of the board to add the image to. + :param image_name: The name of the image to add to the board. + """ + services.board_images.add_image_to_board(board_id, image_name) + + def get_all_image_names_for_board(board_id: str) -> list[str]: + """ + Gets all image names for a board. + + :param board_id: The ID of the board to get the image names for. + """ + return services.board_images.get_all_board_image_names_for_board(board_id) + + self.create = create + self.get_dto = get_dto + self.get_all = get_all + self.add_image_to_board = add_image_to_board + self.get_all_image_names_for_board = get_all_image_names_for_board + + class LoggerInterface: def __init__(self, services: InvocationServices) -> None: def debug(message: str) -> None: @@ -427,6 +476,7 @@ def __init__( logger: LoggerInterface, config: ConfigInterface, util: UtilInterface, + boards: BoardsInterface, data: InvocationContextData, services: InvocationServices, ) -> None: @@ -444,6 +494,8 @@ def __init__( """Provides access to the app's config.""" self.util = util """Provides utility methods.""" + self.boards = boards + """Provides methods to interact with boards.""" self.data = data """Provides data about the current queue item and invocation.""" self.__services = services @@ -554,6 +606,7 @@ def build_invocation_context( config = ConfigInterface(services=services) util = UtilInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data) + boards = BoardsInterface(services=services) ctx = InvocationContext( images=images, @@ -565,6 +618,7 @@ util=util, conditioning=conditioning, services=services, + boards=boards, ) return ctx From bf48e8a03a51253922201bf1b892ee90c3c332e9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:48:32 +1100 Subject: [PATCH 17/67] feat(nodes): export more things from `invocation_api` --- invokeai/invocation_api/__init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index e867ec3cc4e..e80bc26a003 100644 --- a/invokeai/invocation_api/__init__.py 
+++ b/invokeai/invocation_api/__init__.py @@ -47,8 +47,14 @@ StringCollectionOutput, StringOutput, ) +from invokeai.app.services.boards.boards_common import BoardDTO +from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, @@ -101,9 +107,23 @@ "StringOutput", # invokeai.app.services.image_records.image_records_common "ImageCategory", + # invokeai.app.services.boards.boards_common + "BoardDTO", # invokeai.backend.stable_diffusion.diffusion.conditioning_data "BasicConditioningInfo", "ConditioningFieldData", "ExtraConditioningInfo", "SDXLConditioningInfo", + # invokeai.backend.stable_diffusion.diffusers_pipeline + "PipelineIntermediateState", + # invokeai.app.services.workflow_records.workflow_records_common + "WorkflowWithoutID", + # invokeai.app.services.config.config_default + "InvokeAIAppConfig", + # invokeai.backend.model_management.model_manager + "ModelInfo", + # invokeai.backend.model_management.models.base + "BaseModelType", + "ModelType", + "SubModelType", ] From 3f5ab02da977e3c29c9311ce24bce7e170bd9779 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 18:06:01 +1100 Subject: [PATCH 18/67] chore(nodes): add comments for ConfigInterface --- invokeai/app/services/shared/invocation_context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cbcaa6a5489..cb989cb15e0 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -392,7 +392,7 @@ class ConfigInterface: def __init__(self, services: InvocationServices) -> None: def get() -> InvokeAIAppConfig: """ - Gets the app's config. + Gets the app's config. The config is read-only; attempts to mutate it will raise an error. """ # The config can be changed at runtime. @@ -400,6 +400,7 @@ def get() -> InvokeAIAppConfig: # We don't want nodes doing this, so we make a frozen copy. config = services.configuration.get_config() + # TODO(psyche): If config cannot be changed at runtime, should we cache this? 
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) return frozen_config From 9c1e52b1ef6a0606594848f2f65457c10a4cf1e5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 6 Feb 2024 00:37:18 +1100 Subject: [PATCH 19/67] tests(nodes): fix mock InvocationContext --- tests/aa_nodes/test_graph_execution_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 3577a78ae2c..1612cbe7198 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -93,6 +93,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B logger=None, models=None, util=None, + boards=None, services=None, ) ) From afbe889d357392f7109f67fed0cc1d99c52290cc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 6 Feb 2024 10:22:58 +1100 Subject: [PATCH 20/67] fix(nodes): restore missing context type annotations --- invokeai/app/invocations/latent.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 2cc84f80a73..5e36e73ec8f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -106,7 +106,7 @@ class SchedulerInvocation(BaseInvocation): ui_type=UIType.Scheduler, ) - def invoke(self, context) -> SchedulerOutput: + def invoke(self, context: InvocationContext) -> SchedulerOutput: return SchedulerOutput(scheduler=self.scheduler) @@ -141,7 +141,7 @@ def prep_mask_tensor(self, mask_image): return mask_tensor @torch.no_grad() - def invoke(self, context) -> DenoiseMaskOutput: + def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: if self.image is not None: image = context.images.get_pil(self.image.image_name) image = image_resized_to_grid_as_tensor(image.convert("RGB")) @@ -630,7 +630,7 @@ def prep_inpaint_mask(self, context: InvocationContext, latents): return 1 - mask, masked_latents @torch.no_grad() - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: with SilenceWarnings(): # this quenches NSFW nag from diffusers seed = None noise = None @@ -777,7 +777,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @torch.no_grad() - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.latents.get(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -868,7 +868,7 @@ class ResizeLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) # TODO: @@ -909,7 +909,7 @@ class ScaleLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context) -> LatentsOutput: + 
def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) # TODO: @@ -998,7 +998,7 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): return latents @torch.no_grad() - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -1046,7 +1046,7 @@ class BlendLatentsInvocation(BaseInvocation): ) alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents_a = context.latents.get(self.latents_a.latents_name) latents_b = context.latents.get(self.latents_b.latents_name) @@ -1147,7 +1147,7 @@ class CropLatentsCoreInvocation(BaseInvocation): description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", ) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR From b4c774896aef541e2221dc6f185729234b63099c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:24:05 +1100 Subject: [PATCH 21/67] feat(nodes): do not hide `services` in invocation context interfaces --- .../app/services/shared/invocation_context.py | 675 ++++++++---------- 1 file changed, 317 insertions(+), 358 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cb989cb15e0..54c50bcf76b 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -64,379 +64,338 @@ class InvocationContextData: """The workflow associated with this queue item, if any.""" -class BoardsInterface: - def __init__(self, services: InvocationServices) -> None: - def create(board_name: str) -> BoardDTO: - """ - Creates a board. - - :param board_name: The name of the board to create. - """ - return services.boards.create(board_name) - - def get_dto(board_id: str) -> BoardDTO: - """ - Gets a board DTO. - - :param board_id: The ID of the board to get. - """ - return services.boards.get_dto(board_id) - - def get_all() -> list[BoardDTO]: - """ - Gets all boards. - """ - return services.boards.get_all() - - def add_image_to_board(board_id: str, image_name: str) -> None: - """ - Adds an image to a board. - - :param board_id: The ID of the board to add the image to. - :param image_name: The name of the image to add to the board. - """ - services.board_images.add_image_to_board(board_id, image_name) - - def get_all_image_names_for_board(board_id: str) -> list[str]: - """ - Gets all image names for a board. - - :param board_id: The ID of the board to get the image names for. - """ - return services.board_images.get_all_board_image_names_for_board(board_id) - - self.create = create - self.get_dto = get_dto - self.get_all = get_all - self.add_image_to_board = add_image_to_board - self.get_all_image_names_for_board = get_all_image_names_for_board - - -class LoggerInterface: - def __init__(self, services: InvocationServices) -> None: - def debug(message: str) -> None: - """ - Logs a debug message. - - :param message: The message to log. 
- """ - services.logger.debug(message) - - def info(message: str) -> None: - """ - Logs an info message. - - :param message: The message to log. - """ - services.logger.info(message) - - def warning(message: str) -> None: - """ - Logs a warning message. - - :param message: The message to log. - """ - services.logger.warning(message) - - def error(message: str) -> None: - """ - Logs an error message. - - :param message: The message to log. - """ - services.logger.error(message) - - self.debug = debug - self.info = info - self.warning = warning - self.error = error - - -class ImagesInterface: +class InvocationContextInterface: def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def save( - image: Image, - board_id: Optional[str] = None, - image_category: ImageCategory = ImageCategory.GENERAL, - metadata: Optional[MetadataField] = None, - ) -> ImageDTO: - """ - Saves an image, returning its DTO. - - If the current queue item has a workflow or metadata, it is automatically saved with the image. - - :param image: The image to save, as a PIL image. - :param board_id: The board ID to add the image to, if it should be added. - :param image_category: The category of the image. Only the GENERAL category is added \ - to the gallery. - :param metadata: The metadata to save with the image, if it should have any. If the \ - invocation inherits from `WithMetadata`, that metadata will be used automatically. \ - **Use this only if you want to override or provide metadata manually!** - """ - - # If the invocation inherits metadata, use that. Else, use the metadata passed in. - metadata_ = ( - context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata - ) - - return services.images.create( - image=image, - is_intermediate=context_data.invocation.is_intermediate, - image_category=image_category, - board_id=board_id, - metadata=metadata_, - image_origin=ResourceOrigin.INTERNAL, - workflow=context_data.workflow, - session_id=context_data.session_id, - node_id=context_data.invocation.id, - ) - - def get_pil(image_name: str) -> Image: - """ - Gets an image as a PIL Image object. - - :param image_name: The name of the image to get. - """ - return services.images.get_pil_image(image_name) - - def get_metadata(image_name: str) -> Optional[MetadataField]: - """ - Gets an image's metadata, if it has any. - - :param image_name: The name of the image to get the metadata for. - """ - return services.images.get_metadata(image_name) - - def get_dto(image_name: str) -> ImageDTO: - """ - Gets an image as an ImageDTO object. - - :param image_name: The name of the image to get. - """ - return services.images.get_dto(image_name) - - def update( - image_name: str, - board_id: Optional[str] = None, - is_intermediate: Optional[bool] = False, - ) -> ImageDTO: - """ - Updates an image, returning its updated DTO. - - It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - - If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to - get the updated image. - - :param image_name: The name of the image to update. - :param board_id: The board ID to add the image to, if it should be added. - :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. 
- """ - if is_intermediate is not None: - services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) - if board_id is None: - services.board_images.remove_image_from_board(image_name) - else: - services.board_images.add_image_to_board(image_name, board_id) - return services.images.get_dto(image_name) - - self.save = save - self.get_pil = get_pil - self.get_metadata = get_metadata - self.get_dto = get_dto - self.update = update - - -class LatentsInterface: - def __init__( + self._services = services + self._context_data = context_data + + +class BoardsInterface(InvocationContextInterface): + def create(self, board_name: str) -> BoardDTO: + """ + Creates a board. + + :param board_name: The name of the board to create. + """ + return self._services.boards.create(board_name) + + def get_dto(self, board_id: str) -> BoardDTO: + """ + Gets a board DTO. + + :param board_id: The ID of the board to get. + """ + return self._services.boards.get_dto(board_id) + + def get_all(self) -> list[BoardDTO]: + """ + Gets all boards. + """ + return self._services.boards.get_all() + + def add_image_to_board(self, board_id: str, image_name: str) -> None: + """ + Adds an image to a board. + + :param board_id: The ID of the board to add the image to. + :param image_name: The name of the image to add to the board. + """ + return self._services.board_images.add_image_to_board(board_id, image_name) + + def get_all_image_names_for_board(self, board_id: str) -> list[str]: + """ + Gets all image names for a board. + + :param board_id: The ID of the board to get the image names for. + """ + return self._services.board_images.get_all_board_image_names_for_board(board_id) + + +class LoggerInterface(InvocationContextInterface): + def debug(self, message: str) -> None: + """ + Logs a debug message. + + :param message: The message to log. + """ + self._services.logger.debug(message) + + def info(self, message: str) -> None: + """ + Logs an info message. + + :param message: The message to log. + """ + self._services.logger.info(message) + + def warning(self, message: str) -> None: + """ + Logs a warning message. + + :param message: The message to log. + """ + self._services.logger.warning(message) + + def error(self, message: str) -> None: + """ + Logs an error message. + + :param message: The message to log. + """ + self._services.logger.error(message) + + +class ImagesInterface(InvocationContextInterface): + def save( self, - services: InvocationServices, - context_data: InvocationContextData, - ) -> None: - def save(tensor: Tensor) -> str: - """ - Saves a latents tensor, returning its name. - - :param tensor: The latents tensor to save. - """ - - # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. - # "mask", "noise", "masked_latents", etc. - # - # Retaining that capability in this wrapper would require either many different methods - # to save latents, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all latents. - # - # This has a very minor impact as we don't use them after a session completes. - - # Previously, invocations chose the name for their latents. This is a bit risky, so we - # will generate a name for them instead. We use a uuid to ensure the name is unique. - # - # Because the name of the latents file will includes the session and invocation IDs, - # we don't need to worry about collisions. A truncated UUIDv4 is fine. 
- - name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" - services.latents.save( - name=name, - data=tensor, - ) - return name - - def get(latents_name: str) -> Tensor: - """ - Gets a latents tensor by name. - - :param latents_name: The name of the latents tensor to get. - """ - return services.latents.get(latents_name) - - self.save = save - self.get = get - - -class ConditioningInterface: - def __init__( + image: Image, + board_id: Optional[str] = None, + image_category: ImageCategory = ImageCategory.GENERAL, + metadata: Optional[MetadataField] = None, + ) -> ImageDTO: + """ + Saves an image, returning its DTO. + + If the current queue item has a workflow or metadata, it is automatically saved with the image. + + :param image: The image to save, as a PIL image. + :param board_id: The board ID to add the image to, if it should be added. + :param image_category: The category of the image. Only the GENERAL category is added \ + to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the \ + invocation inherits from `WithMetadata`, that metadata will be used automatically. \ + **Use this only if you want to override or provide metadata manually!** + """ + + # If the invocation inherits metadata, use that. Else, use the metadata passed in. + metadata_ = ( + self._context_data.invocation.metadata + if isinstance(self._context_data.invocation, WithMetadata) + else metadata + ) + + return self._services.images.create( + image=image, + is_intermediate=self._context_data.invocation.is_intermediate, + image_category=image_category, + board_id=board_id, + metadata=metadata_, + image_origin=ResourceOrigin.INTERNAL, + workflow=self._context_data.workflow, + session_id=self._context_data.session_id, + node_id=self._context_data.invocation.id, + ) + + def get_pil(self, image_name: str) -> Image: + """ + Gets an image as a PIL Image object. + + :param image_name: The name of the image to get. + """ + return self._services.images.get_pil_image(image_name) + + def get_metadata(self, image_name: str) -> Optional[MetadataField]: + """ + Gets an image's metadata, if it has any. + + :param image_name: The name of the image to get the metadata for. + """ + return self._services.images.get_metadata(image_name) + + def get_dto(self, image_name: str) -> ImageDTO: + """ + Gets an image as an ImageDTO object. + + :param image_name: The name of the image to get. + """ + return self._services.images.get_dto(image_name) + + def update( self, - services: InvocationServices, - context_data: InvocationContextData, - ) -> None: - # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed to work with Tensors only. We have to fudge the types here. + image_name: str, + board_id: Optional[str] = None, + is_intermediate: Optional[bool] = False, + ) -> ImageDTO: + """ + Updates an image, returning its updated DTO. + + It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - def save(conditioning_data: ConditioningFieldData) -> str: - """ - Saves a conditioning data object, returning its name. + If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to + get the updated image. - :param conditioning_data: The conditioning data to save. - """ + :param image_name: The name of the image to update. + :param board_id: The board ID to add the image to, if it should be added. 
+ :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. + """ + if is_intermediate is not None: + self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) + if board_id is None: + self._services.board_images.remove_image_from_board(image_name) + else: + self._services.board_images.add_image_to_board(image_name, board_id) + return self._services.images.get_dto(image_name) - # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. - # - # See comment for `LatentsInterface.save` for more info about this method (it's very - # similar). - name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" - services.latents.save( - name=name, - data=conditioning_data, # type: ignore [arg-type] - ) - return name +class LatentsInterface(InvocationContextInterface): + def save(self, tensor: Tensor) -> str: + """ + Saves a latents tensor, returning its name. - def get(conditioning_name: str) -> ConditioningFieldData: - """ - Gets conditioning data by name. + :param tensor: The latents tensor to save. + """ - :param conditioning_name: The name of the conditioning data to get. - """ + # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. + # "mask", "noise", "masked_latents", etc. + # + # Retaining that capability in this wrapper would require either many different methods + # to save latents, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all latents. + # + # This has a very minor impact as we don't use them after a session completes. + + # Previously, invocations chose the name for their latents. This is a bit risky, so we + # will generate a name for them instead. We use a uuid to ensure the name is unique. + # + # Because the name of the latents file will includes the session and invocation IDs, + # we don't need to worry about collisions. A truncated UUIDv4 is fine. + + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" + self._services.latents.save( + name=name, + data=tensor, + ) + return name + + def get(self, latents_name: str) -> Tensor: + """ + Gets a latents tensor by name. - return services.latents.get(conditioning_name) # type: ignore [return-value] + :param latents_name: The name of the latents tensor to get. + """ + return self._services.latents.get(latents_name) - self.save = save - self.get = get +class ConditioningInterface(InvocationContextInterface): + # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed to work with Tensors only. We have to fudge the types here. + def save(self, conditioning_data: ConditioningFieldData) -> str: + """ + Saves a conditioning data object, returning its name. -class ModelsInterface: - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: - """ - Checks if a model exists. - - :param model_name: The name of the model to check. - :param base_model: The base model of the model to check. - :param model_type: The type of the model to check. 
- """ - return services.model_manager.model_exists(model_name, base_model, model_type) - - def load( - model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None - ) -> ModelInfo: - """ - Loads a model, returning its `ModelInfo` object. - - :param model_name: The name of the model to get. - :param base_model: The base model of the model to get. - :param model_type: The type of the model to get. - :param submodel: The submodel of the model to get. - """ - - # During this call, the model manager emits events with model loading status. The model - # manager itself has access to the events services, but does not have access to the - # required metadata for the events. - # - # For example, it needs access to the node's ID so that the events can be associated - # with the execution of a specific node. - # - # While this is available within the node, it's tedious to need to pass it in on every - # call. We can avoid that by wrapping the method here. - - return services.model_manager.get_model( - model_name, base_model, model_type, submodel, context_data=context_data - ) - - def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Gets a model's info, an dict-like object. - - :param model_name: The name of the model to get. - :param base_model: The base model of the model to get. - :param model_type: The type of the model to get. - """ - return services.model_manager.model_info(model_name, base_model, model_type) - - self.exists = exists - self.load = load - self.get_info = get_info - - -class ConfigInterface: - def __init__(self, services: InvocationServices) -> None: - def get() -> InvokeAIAppConfig: - """ - Gets the app's config. The config is read-only; attempts to mutate it will raise an error. - """ - - # The config can be changed at runtime. - # - # We don't want nodes doing this, so we make a frozen copy. - - config = services.configuration.get_config() - # TODO(psyche): If config cannot be changed at runtime, should we cache this? - frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) - return frozen_config - - self.get = get - - -class UtilInterface: - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def sd_step_callback( - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - """ - The step callback emits a progress event with the current step, the total number of - steps, a preview image, and some other internal metadata. + :param conditioning_context_data: The conditioning data to save. + """ + + # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. + # + # See comment for `LatentsInterface.save` for more info about this method (it's very + # similar). - This should be called after each denoising step. + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning" + self._services.latents.save( + name=name, + data=conditioning_data, # type: ignore [arg-type] + ) + return name - :param intermediate_state: The intermediate state of the diffusion pipeline. - :param base_model: The base model for the current denoising step. - """ + def get(self, conditioning_name: str) -> ConditioningFieldData: + """ + Gets conditioning data by name. + + :param conditioning_name: The name of the conditioning data to get. 
+ """ - # The step callback needs access to the events and the invocation queue services, but this - # represents a dangerous level of access. - # - # We wrap the step callback so that nodes do not have direct access to these services. + return self._services.latents.get(conditioning_name) # type: ignore [return-value] + + +class ModelsInterface(InvocationContextInterface): + def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: + """ + Checks if a model exists. + + :param model_name: The name of the model to check. + :param base_model: The base model of the model to check. + :param model_type: The type of the model to check. + """ + return self._services.model_manager.model_exists(model_name, base_model, model_type) + + def load( + self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None + ) -> ModelInfo: + """ + Loads a model, returning its `ModelInfo` object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + :param submodel: The submodel of the model to get. + """ + + # During this call, the model manager emits events with model loading status. The model + # manager itself has access to the events services, but does not have access to the + # required metadata for the events. + # + # For example, it needs access to the node's ID so that the events can be associated + # with the execution of a specific node. + # + # While this is available within the node, it's tedious to need to pass it in on every + # call. We can avoid that by wrapping the method here. + + return self._services.model_manager.get_model( + model_name, base_model, model_type, submodel, context_data=self._context_data + ) + + def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Gets a model's info, an dict-like object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + """ + return self._services.model_manager.model_info(model_name, base_model, model_type) + + +class ConfigInterface(InvocationContextInterface): + def get(self) -> InvokeAIAppConfig: + """ + Gets the app's config. The config is read-only; attempts to mutate it will raise an error. + """ + + # The config can be changed at runtime. + # + # We don't want nodes doing this, so we make a frozen copy. + + config = self._services.configuration.get_config() + # TODO(psyche): If config cannot be changed at runtime, should we cache this? + frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) + return frozen_config + + +class UtilInterface(InvocationContextInterface): + def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: + """ + The step callback emits a progress event with the current step, the total number of + steps, a preview image, and some other internal metadata. + + This should be called after each denoising step. + + :param intermediate_state: The intermediate state of the diffusion pipeline. + :param base_model: The base model for the current denoising step. 
+ """ - stable_diffusion_step_callback( - context_data=context_data, - intermediate_state=intermediate_state, - base_model=base_model, - invocation_queue=services.queue, - events=services.events, - ) + # The step callback needs access to the events and the invocation queue services, but this + # represents a dangerous level of access. + # + # We wrap the step callback so that nodes do not have direct access to these services. - self.sd_step_callback = sd_step_callback + stable_diffusion_step_callback( + context_data=self._context_data, + intermediate_state=intermediate_state, + base_model=base_model, + invocation_queue=self._services.queue, + events=self._services.events, + ) deprecation_version = "3.7.0" @@ -600,14 +559,14 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. """ - logger = LoggerInterface(services=services) + logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) latents = LatentsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) - config = ConfigInterface(services=services) + config = ConfigInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data) - boards = BoardsInterface(services=services) + boards = BoardsInterface(services=services, context_data=context_data) ctx = InvocationContext( images=images, From 889a26c5b64fdfdadc9f834b331076084a9cf3b2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:28:29 +1100 Subject: [PATCH 22/67] feat(nodes): cache invocation interface config --- .../app/services/shared/invocation_context.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 54c50bcf76b..99e439ad96d 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -357,19 +357,24 @@ def get_info(self, model_name: str, base_model: BaseModelType, model_type: Model class ConfigInterface(InvocationContextInterface): + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + super().__init__(services, context_data) + # Config cache, only populated at runtime if requested + self._frozen_config: Optional[InvokeAIAppConfig] = None + def get(self) -> InvokeAIAppConfig: """ Gets the app's config. The config is read-only; attempts to mutate it will raise an error. """ - # The config can be changed at runtime. - # - # We don't want nodes doing this, so we make a frozen copy. + if self._frozen_config is None: + # The config is a live pydantic model and can be changed at runtime. + # We don't want nodes doing this, so we make a frozen copy. + self._frozen_config = self._services.configuration.get_config().model_copy( + update={"model_config": ConfigDict(frozen=True)} + ) - config = self._services.configuration.get_config() - # TODO(psyche): If config cannot be changed at runtime, should we cache this? 
- frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) - return frozen_config + return self._frozen_config class UtilInterface(InvocationContextInterface): From d7b7dcc7fefe30b7c2d4cca578eecbd4ce52de1b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:36:42 +1100 Subject: [PATCH 23/67] feat(nodes): context.__services -> context._services --- invokeai/app/services/shared/invocation_context.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 99e439ad96d..5da85931672 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -463,7 +463,8 @@ def __init__( """Provides methods to interact with boards.""" self.data = data """Provides data about the current queue item and invocation.""" - self.__services = services + self._services = services + """Provides access to the full application services. This is an internal API and may change without warning.""" @property @deprecated(version=deprecation_version, reason=get_deprecation_reason("`context.services`")) @@ -475,7 +476,7 @@ def services(self) -> InvocationServices: The invocation services. """ - return self.__services + return self._services @property @deprecated( From aff46759f9a8fffdb670347dab73ad68974b7bb4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:39:26 +1100 Subject: [PATCH 24/67] feat(nodes): context.data -> context._data --- .../app/services/shared/invocation_context.py | 38 +++++++++---------- tests/aa_nodes/test_graph_execution_state.py | 2 +- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 5da85931672..b48a6acc545 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -442,7 +442,7 @@ def __init__( config: ConfigInterface, util: UtilInterface, boards: BoardsInterface, - data: InvocationContextData, + context_data: InvocationContextData, services: InvocationServices, ) -> None: self.images = images @@ -461,8 +461,8 @@ def __init__( """Provides utility methods.""" self.boards = boards """Provides methods to interact with boards.""" - self.data = data - """Provides data about the current queue item and invocation.""" + self._data = context_data + """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" self._services = services """Provides access to the full application services. This is an internal API and may change without warning.""" @@ -481,77 +481,77 @@ def services(self) -> InvocationServices: @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.graph_execution_state_api`", "`context.data.session_id`"), + reason=get_deprecation_reason("`context.graph_execution_state_id", "`context._data.session_id`"), ) def graph_execution_state_id(self) -> str: """ **DEPRECATED as of v3.7.0** - `context.graph_execution_state_api` will be removed in v3.8.0. Use `context.data.session_id` instead. See PLACEHOLDER_URL for details. + `context.graph_execution_state_api` will be removed in v3.8.0. Use `context._data.session_id` instead. See PLACEHOLDER_URL for details. 
The ID of the session (aka graph execution state). """ - return self.data.session_id + return self._data.session_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.queue_id`", "`context.data.queue_id`"), + reason=get_deprecation_reason("`context.queue_id`", "`context._data.queue_id`"), ) def queue_id(self) -> str: """ **DEPRECATED as of v3.7.0** - `context.queue_id` will be removed in v3.8.0. Use `context.data.queue_id` instead. See PLACEHOLDER_URL for details. + `context.queue_id` will be removed in v3.8.0. Use `context._data.queue_id` instead. See PLACEHOLDER_URL for details. The ID of the queue. """ - return self.data.queue_id + return self._data.queue_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.queue_item_id`", "`context.data.queue_item_id`"), + reason=get_deprecation_reason("`context.queue_item_id`", "`context._data.queue_item_id`"), ) def queue_item_id(self) -> int: """ **DEPRECATED as of v3.7.0** - `context.queue_item_id` will be removed in v3.8.0. Use `context.data.queue_item_id` instead. See PLACEHOLDER_URL for details. + `context.queue_item_id` will be removed in v3.8.0. Use `context._data.queue_item_id` instead. See PLACEHOLDER_URL for details. The ID of the queue item. """ - return self.data.queue_item_id + return self._data.queue_item_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.queue_batch_id`", "`context.data.batch_id`"), + reason=get_deprecation_reason("`context.queue_batch_id`", "`context._data.batch_id`"), ) def queue_batch_id(self) -> str: """ **DEPRECATED as of v3.7.0** - `context.queue_batch_id` will be removed in v3.8.0. Use `context.data.batch_id` instead. See PLACEHOLDER_URL for details. + `context.queue_batch_id` will be removed in v3.8.0. Use `context._data.batch_id` instead. See PLACEHOLDER_URL for details. The ID of the batch. """ - return self.data.batch_id + return self._data.batch_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.workflow`", "`context.data.workflow`"), + reason=get_deprecation_reason("`context.workflow`", "`context._data.workflow`"), ) def workflow(self) -> Optional[WorkflowWithoutID]: """ **DEPRECATED as of v3.7.0** - `context.workflow` will be removed in v3.8.0. Use `context.data.workflow` instead. See PLACEHOLDER_URL for details. + `context.workflow` will be removed in v3.8.0. Use `context._data.workflow` instead. See PLACEHOLDER_URL for details. The workflow associated with this queue item, if any. 
""" - return self.data.workflow + return self._data.workflow def build_invocation_context( @@ -580,7 +580,7 @@ def build_invocation_context( config=config, latents=latents, models=models, - data=context_data, + context_data=context_data, util=util, conditioning=conditioning, services=services, diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 1612cbe7198..aba7c5694f3 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -87,7 +87,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B InvocationContext( conditioning=None, config=None, - data=None, + context_data=None, images=None, latents=None, logger=None, From b1ba18b3d1c70e17e72afc2f7d244dcc2cea3797 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 15:58:46 +1100 Subject: [PATCH 25/67] fix(nodes): do not freeze or cache config in context wrapper - The config is already cached by the config class's `get_config()` method. - The config mutates itself in its `root_path` property getter. Freezing the class makes any attempt to grab a path from the config error. Unfortunately this means we cannot easily freeze the class without fiddling with the inner workings of `InvokeAIAppConfig`, which is outside the scope here. --- .../app/services/shared/invocation_context.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index b48a6acc545..cd88ec876dd 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -357,24 +357,10 @@ def get_info(self, model_name: str, base_model: BaseModelType, model_type: Model class ConfigInterface(InvocationContextInterface): - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - super().__init__(services, context_data) - # Config cache, only populated at runtime if requested - self._frozen_config: Optional[InvokeAIAppConfig] = None - def get(self) -> InvokeAIAppConfig: - """ - Gets the app's config. The config is read-only; attempts to mutate it will raise an error. - """ - - if self._frozen_config is None: - # The config is a live pydantic model and can be changed at runtime. - # We don't want nodes doing this, so we make a frozen copy. 
- self._frozen_config = self._services.configuration.get_config().model_copy( - update={"model_config": ConfigDict(frozen=True)} - ) + """Gets the app's config.""" - return self._frozen_config + return self._services.configuration.get_config() class UtilInterface(InvocationContextInterface): From e36d925bce497d2c5d2f29fe6b3e2685c4197125 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:14:35 +1100 Subject: [PATCH 26/67] fix(ui): remove original l2i node in HRF graph --- .../web/src/features/nodes/util/graph/addHrfToGraph.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts index 7413302fa57..8a4448833cf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts @@ -314,6 +314,10 @@ export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void = ); copyConnectionsToDenoiseLatentsHrf(graph); + // The original l2i node is unnecessary now, remove it + graph.edges = graph.edges.filter((edge) => edge.destination.node_id !== LATENTS_TO_IMAGE); + delete graph.nodes[LATENTS_TO_IMAGE]; + graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = { type: 'l2i', id: LATENTS_TO_IMAGE_HRF_HR, From 1a191c4655e0fa59500f84dd96a23f9beca83ee4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:23:57 +1100 Subject: [PATCH 27/67] remove unused configdict import --- invokeai/app/services/shared/invocation_context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cd88ec876dd..8aaa5233afd 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -3,7 +3,6 @@ from deprecated import deprecated from PIL.Image import Image -from pydantic import ConfigDict from torch import Tensor from invokeai.app.invocations.fields import MetadataField, WithMetadata From c16eba78ab630b4e7f2eec74285c21509a81ca37 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:33:55 +1100 Subject: [PATCH 28/67] feat(nodes): add `WithBoard` field helper class This class works the same way as `WithMetadata` - it simply adds a `board` field to the node. The context wrapper function is able to pull the board id from this. This allows image-outputting nodes to get a board field "for free", and have their outputs automatically saved to it. This is a breaking change for node authors who may have a field called `board`, because it makes `board` a reserved field name. I'll look into how to avoid this - maybe by naming this invoke-managed field `_board` to avoid collisions? Supporting changes: - `WithBoard` is added to all image-outputting nodes, giving them the ability to save to board. - Unused, duplicate `WithMetadata` and `WithWorkflow` classes are deleted from `baseinvocation.py`. The "real" versions are in `fields.py`. - Remove `LinearUIOutputInvocation`. Now that all nodes that output images also have a `board` field by default, this node is no longer necessary. 
See comment here for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629
- Without `LinearUIOutputInvocation`, the `ImagesInterface.update` method is no longer needed, and has been removed.
Note: This commit does not bump all node versions. I will ensure that is done correctly before merging the PR of which this commit is a part.
Note: A followup commit will implement the frontend changes to support this change.
---
 invokeai/app/invocations/baseinvocation.py | 33 +-------
 .../controlnet_image_processors.py | 12 ++-
 invokeai/app/invocations/cv.py | 4 +-
 invokeai/app/invocations/facetools.py | 4 +-
 invokeai/app/invocations/fields.py | 16 ++++
 invokeai/app/invocations/image.py | 76 ++++++-------------
 invokeai/app/invocations/infill.py | 12 +--
 invokeai/app/invocations/latent.py | 3 +-
 invokeai/app/invocations/primitives.py | 4 +-
 invokeai/app/invocations/tiles.py | 4 +-
 invokeai/app/invocations/upscale.py | 4 +-
 .../app/services/shared/invocation_context.py | 40 +++-------
 12 files changed, 78 insertions(+), 134 deletions(-)
diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py
index df0596c9a15..3243714937f 100644
--- a/invokeai/app/invocations/baseinvocation.py
+++ b/invokeai/app/invocations/baseinvocation.py
@@ -17,11 +17,8 @@ from pydantic_core import PydanticUndefined
 from invokeai.app.invocations.fields import (
- FieldDescriptions,
 FieldKind,
 Input,
- InputFieldJSONSchemaExtra,
- MetadataField,
 )
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -306,9 +303,7 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
 "workflow",
 }
-RESERVED_INPUT_FIELD_NAMES = {
- "metadata",
-}
+RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
 RESERVED_OUTPUT_FIELD_NAMES = {"type"}
@@ -518,29 +513,3 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
 return cls
 return wrapper
-
-
-class WithMetadata(BaseModel):
- """
- Inherit from this class if your node needs a metadata input field.
- """
-
- metadata: Optional[MetadataField] = Field(
- default=None,
- description=FieldDescriptions.metadata,
- json_schema_extra=InputFieldJSONSchemaExtra(
- field_kind=FieldKind.Internal,
- input=Input.Connection,
- orig_required=False,
- ).model_dump(exclude_none=True),
- )
-
-
-class WithWorkflow:
- workflow = None
-
- def __init_subclass__(cls) -> None:
- logger.warn(
- f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
- ) - super().__init_subclass__() diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index f8bdf14117c..37954c1097e 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,7 +25,15 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + Input, + InputField, + OutputField, + WithBoard, + WithMetadata, +) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext @@ -135,7 +143,7 @@ def invoke(self, context: InvocationContext) -> ControlOutput: # This invocation exists for other invocations to subclass it - do not register with @invocation! -class ImageProcessorInvocation(BaseInvocation, WithMetadata): +class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): """Base class for invocations that preprocess images for ControlNet""" image: ImageField = InputField(description="The image to process") diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 1ebabf5e064..8174f19b64c 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -10,11 +10,11 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation -from .fields import InputField, WithMetadata +from .fields import InputField, WithBoard, WithMetadata @invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1") -class CvInpaintInvocation(BaseInvocation, WithMetadata): +class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard): """Simple inpaint using opencv.""" image: ImageField = InputField(description="The image to inpaint") diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index a1702d6517c..fed2ed5e4f2 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -16,7 +16,7 @@ invocation, invocation_output, ) -from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithBoard, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext @@ -619,7 +619,7 @@ def invoke(self, context: InvocationContext) -> FaceMaskOutput: @invocation( "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1" ) -class FaceIdentifierInvocation(BaseInvocation, WithMetadata): +class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard): """Outputs an image with detected face IDs printed on each face. 
For use with other FaceTools.""" image: ImageField = InputField(description="Image to face detect") diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 8879f760770..c42d2f83120 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -280,6 +280,22 @@ def __init_subclass__(cls) -> None: super().__init_subclass__() +class WithBoard(BaseModel): + """ + Inherit from this class if your node needs a board input field. + """ + + board: Optional["BoardField"] = Field( + default=None, + description=FieldDescriptions.board, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Direct, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + class OutputFieldJSONSchemaExtra(BaseModel): """ Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 7b74e4d96d4..f5ad5515a68 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -8,12 +8,11 @@ from PIL import Image, ImageChops, ImageFilter, ImageOps from invokeai.app.invocations.fields import ( - BoardField, ColorField, FieldDescriptions, ImageField, - Input, InputField, + WithBoard, WithMetadata, ) from invokeai.app.invocations.primitives import ImageOutput @@ -55,7 +54,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class BlankImageInvocation(BaseInvocation, WithMetadata): +class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Creates a blank image and forwards it to the pipeline""" width: int = InputField(default=512, description="The width of the image") @@ -78,7 +77,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageCropInvocation(BaseInvocation, WithMetadata): +class ImageCropInvocation(BaseInvocation, WithMetadata, WithBoard): """Crops an image to a specified box. 
The box can be outside of the image.""" image: ImageField = InputField(description="The image to crop") @@ -149,7 +148,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImagePasteInvocation(BaseInvocation, WithMetadata): +class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard): """Pastes an image into another image.""" base_image: ImageField = InputField(description="The base image") @@ -196,7 +195,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): +class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard): """Extracts the alpha channel of an image as a mask.""" image: ImageField = InputField(description="The image to create the mask from") @@ -221,7 +220,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageMultiplyInvocation(BaseInvocation, WithMetadata): +class ImageMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" image1: ImageField = InputField(description="The first image to multiply") @@ -248,7 +247,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageChannelInvocation(BaseInvocation, WithMetadata): +class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard): """Gets a channel from an image.""" image: ImageField = InputField(description="The image to get the channel from") @@ -274,7 +273,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageConvertInvocation(BaseInvocation, WithMetadata): +class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard): """Converts an image to a different mode.""" image: ImageField = InputField(description="The image to convert") @@ -297,7 +296,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageBlurInvocation(BaseInvocation, WithMetadata): +class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard): """Blurs an image""" image: ImageField = InputField(description="The image to blur") @@ -326,7 +325,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: version="1.2.1", classification=Classification.Beta, ) -class UnsharpMaskInvocation(BaseInvocation, WithMetadata): +class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard): """Applies an unsharp mask filter to an image""" image: ImageField = InputField(description="The image to use") @@ -394,7 +393,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageResizeInvocation(BaseInvocation, WithMetadata): +class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard): """Resizes an image to specific dimensions""" image: ImageField = InputField(description="The image to resize") @@ -424,7 +423,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageScaleInvocation(BaseInvocation, WithMetadata): +class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard): """Scales an image by a factor""" image: ImageField = InputField(description="The image to scale") @@ -459,7 +458,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageLerpInvocation(BaseInvocation, 
WithMetadata): +class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard): """Linear interpolation of all pixels of an image""" image: ImageField = InputField(description="The image to lerp") @@ -486,7 +485,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): +class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard): """Inverse linear interpolation of all pixels of an image""" image: ImageField = InputField(description="The image to lerp") @@ -513,7 +512,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): +class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard): """Add blur to NSFW-flagged images""" image: ImageField = InputField(description="The image to check") @@ -548,7 +547,7 @@ def _get_caution_img(self) -> Image.Image: category="image", version="1.2.1", ) -class ImageWatermarkInvocation(BaseInvocation, WithMetadata): +class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard): """Add an invisible watermark to an image""" image: ImageField = InputField(description="The image to check") @@ -569,7 +568,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class MaskEdgeInvocation(BaseInvocation, WithMetadata): +class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): """Applies an edge mask to an image""" image: ImageField = InputField(description="The image to apply the mask to") @@ -608,7 +607,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class MaskCombineInvocation(BaseInvocation, WithMetadata): +class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" mask1: ImageField = InputField(description="The first mask to combine") @@ -632,7 +631,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ColorCorrectInvocation(BaseInvocation, WithMetadata): +class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard): """ Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. 
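As the repeated one-line changes above show, opting in to board support is simply a matter of adding `WithBoard` to a node's bases. A minimal sketch of what this looks like for a custom node (the node name and class here are hypothetical, not part of this patch; it assumes the same imports used by the built-in nodes):

```python
# Hypothetical example node - not part of this patch.
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation("example_copy", title="Example Copy", tags=["image"], category="image", version="1.0.0")
class ExampleCopyInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Saves a copy of the input image; inheriting WithBoard adds the `board` input for free."""

    image: ImageField = InputField(description="The image to copy")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name)
        # No board_id is passed here: save() pulls the board from the WithBoard
        # mixin, so the output lands on whichever board the user selected.
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
```

Because `context.images.save()` reads the `board` field via the `WithBoard` mixin, the node needs no board plumbing of its own; passing `board_id` to `save()` explicitly remains available as a manual override.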
@@ -736,7 +735,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): +class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard): """Adjusts the Hue of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -825,7 +824,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): +class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard): """Add or subtract a value from a specific color channel of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -881,7 +880,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): +class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): """Scale a specific color channel of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -926,41 +925,14 @@ def invoke(self, context: InvocationContext) -> ImageOutput: version="1.2.1", use_cache=False, ) -class SaveImageInvocation(BaseInvocation, WithMetadata): +class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Saves an image. Unlike an image primitive, this invocation stores a copy of the image.""" image: ImageField = InputField(description=FieldDescriptions.image) - board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) - image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) - - return ImageOutput.build(image_dto) - - -@invocation( - "linear_ui_output", - title="Linear UI Image Output", - tags=["primitives", "image"], - category="primitives", - version="1.0.2", - use_cache=False, -) -class LinearUIOutputInvocation(BaseInvocation, WithMetadata): - """Handles Linear UI Image Outputting tasks.""" - - image: ImageField = InputField(description=FieldDescriptions.image) - board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - - def invoke(self, context: InvocationContext) -> ImageOutput: - image_dto = context.images.get_dto(self.image.image_name) - - image_dto = context.images.update( - image_name=self.image.image_name, - board_id=self.board.board_id if self.board else None, - is_intermediate=self.is_intermediate, - ) + image_dto = context.images.save(image=image) return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index b007edd9e42..53f6f4732fe 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -15,7 +15,7 @@ from invokeai.backend.image_util.patchmatch import PatchMatch from .baseinvocation import BaseInvocation, invocation -from .fields import InputField, WithMetadata +from .fields import InputField, WithBoard, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES @@ -121,7 +121,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] @invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class InfillColorInvocation(BaseInvocation, WithMetadata): 
+class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image with a solid color""" image: ImageField = InputField(description="The image to infill") @@ -144,7 +144,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") -class InfillTileInvocation(BaseInvocation, WithMetadata): +class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image with tiles of the image""" image: ImageField = InputField(description="The image to infill") @@ -170,7 +170,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation( "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1" ) -class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): +class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using the PatchMatch algorithm""" image: ImageField = InputField(description="The image to infill") @@ -209,7 +209,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class LaMaInfillInvocation(BaseInvocation, WithMetadata): +class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using the LaMa model""" image: ImageField = InputField(description="The image to infill") @@ -225,7 +225,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class CV2InfillInvocation(BaseInvocation, WithMetadata): +class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using OpenCV Inpainting""" image: ImageField = InputField(description="The image to infill") diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5e36e73ec8f..5449ec9af7a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -33,6 +33,7 @@ LatentsField, OutputField, UIType, + WithBoard, WithMetadata, ) from invokeai.app.invocations.ip_adapter import IPAdapterField @@ -762,7 +763,7 @@ def _lora_loader(): category="latents", version="1.2.1", ) -class LatentsToImageInvocation(BaseInvocation, WithMetadata): +class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Generates an image from latents.""" latents: LatentsField = InputField( diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index c90d3230b2b..a77939943ae 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -255,9 +255,7 @@ class ImageCollectionOutput(BaseInvocationOutput): @invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1") -class ImageInvocation( - BaseInvocation, -): +class ImageInvocation(BaseInvocation): """An image primitive value""" image: ImageField = InputField(description="The image to load") diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index 19ece423761..cb5373bbf75 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -11,7 +11,7 @@ invocation, invocation_output, ) -from 
invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata
+from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithBoard, WithMetadata
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.tiles.tiles import (
@@ -232,7 +232,7 @@ def invoke(self, context: InvocationContext) -> PairTileImageOutput:
 version="1.1.0",
 classification=Classification.Beta,
 )
-class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
+class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 """Merge multiple tile images into a single image."""
 # Inputs
diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py
index 71ef7ca3aa0..2e2a6ce8813 100644
--- a/invokeai/app/invocations/upscale.py
+++ b/invokeai/app/invocations/upscale.py
@@ -16,7 +16,7 @@ from invokeai.backend.util.devices import choose_torch_device
 from .baseinvocation import BaseInvocation, invocation
-from .fields import InputField, WithMetadata
+from .fields import InputField, WithBoard, WithMetadata
 # TODO: Populate this from disk?
 # TODO: Use model manager to load?
@@ -32,7 +32,7 @@
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1")
-class ESRGANInvocation(BaseInvocation, WithMetadata):
+class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
 """Upscales an image using RealESRGAN."""
 image: ImageField = InputField(description="The input image")
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 8aaa5233afd..97a62246fbc 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -5,10 +5,10 @@
 from PIL.Image import Image
 from torch import Tensor
-from invokeai.app.invocations.fields import MetadataField, WithMetadata
+from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
 from invokeai.app.services.boards.boards_common import BoardDTO
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
@@ -158,7 +158,9 @@ def save(
 If the current queue item has a workflow or metadata, it is automatically saved with the image.
 :param image: The image to save, as a PIL image.
- :param board_id: The board ID to add the image to, if it should be added.
+ :param board_id: The board ID to add the image to, if it should be added. If the invocation \
+ inherits from `WithBoard`, that board will be used automatically. **Use this only if \
+ you want to override or provide a board manually!**
 :param image_category: The category of the image. Only the GENERAL category is added \
 to the gallery.
 :param metadata: The metadata to save with the image, if it should have any. If the \
@@ -173,11 +175,15 @@
 else metadata
 )
+ # If the invocation inherits WithBoard, use that. Else, use the board_id passed in.
+ board_ = self._context_data.invocation.board if isinstance(self._context_data.invocation, WithBoard) else None + board_id_ = board_.board_id if board_ is not None else board_id + return self._services.images.create( image=image, is_intermediate=self._context_data.invocation.is_intermediate, image_category=image_category, - board_id=board_id, + board_id=board_id_, metadata=metadata_, image_origin=ResourceOrigin.INTERNAL, workflow=self._context_data.workflow, @@ -209,32 +215,6 @@ def get_dto(self, image_name: str) -> ImageDTO: """ return self._services.images.get_dto(image_name) - def update( - self, - image_name: str, - board_id: Optional[str] = None, - is_intermediate: Optional[bool] = False, - ) -> ImageDTO: - """ - Updates an image, returning its updated DTO. - - It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - - If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to - get the updated image. - - :param image_name: The name of the image to update. - :param board_id: The board ID to add the image to, if it should be added. - :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. - """ - if is_intermediate is not None: - self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) - if board_id is None: - self._services.board_images.remove_image_from_board(image_name) - else: - self._services.board_images.add_image_to_board(image_name, board_id) - return self._services.images.get_dto(image_name) - class LatentsInterface(InvocationContextInterface): def save(self, tensor: Tensor) -> str: From a3faa3792a91a8898674b160aea12e768bab2ac3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:34:40 +1100 Subject: [PATCH 29/67] chore(ui): regen types --- .../frontend/web/src/services/api/schema.ts | 223 ++++++++---------- 1 file changed, 96 insertions(+), 127 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index da036b6d40a..45358ed97d5 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -968,6 +968,8 @@ export type components = { * @description Creates a blank image and forwards it to the pipeline */ BlankImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -1860,6 +1862,8 @@ export type components = { * @description Infills transparent areas of an image using OpenCV Inpainting */ CV2InfillInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2095,6 +2099,8 @@ export type components = { * @description Canny edge detection for ControlNet */ CannyImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2482,6 +2488,8 @@ export type components = { * using a mask to only color-correct certain regions of the 
target image. */ ColorCorrectInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2590,6 +2598,8 @@ export type components = { * @description Generates a color map from the provided image */ ColorMapImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2797,6 +2807,8 @@ export type components = { * @description Applies content shuffle processing to image */ ContentShuffleImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3442,6 +3454,8 @@ export type components = { * @description Simple inpaint using opencv. */ CvInpaintInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3677,6 +3691,8 @@ export type components = { * @description Generates a depth map based on the Depth Anything algorithm */ DepthAnythingImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3910,6 +3926,8 @@ export type components = { * @description Upscales an image using RealESRGAN. */ ESRGANInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4041,6 +4059,8 @@ export type components = { * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. */ FaceIdentifierInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4873,6 +4893,8 @@ export type components = { * @description Applies HED edge detection to image */ HedImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5324,6 +5346,8 @@ export type components = { * @description Blurs an image */ ImageBlurInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5382,6 +5406,8 @@ export type components = { * @description Gets a channel from an image. 
*/ ImageChannelInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5422,6 +5448,8 @@ export type components = { * @description Scale a specific color channel of an image. */ ImageChannelMultiplyInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5473,6 +5501,8 @@ export type components = { * @description Add or subtract a value from a specific color channel of an image. */ ImageChannelOffsetInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5640,6 +5670,8 @@ export type components = { * @description Converts an image to a different mode. */ ImageConvertInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5680,6 +5712,8 @@ export type components = { * @description Crops an image to a specified box. The box can be outside of the image. */ ImageCropInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5949,6 +5983,8 @@ export type components = { * @description Adjusts the Hue of an image. */ ImageHueAdjustmentInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5988,6 +6024,8 @@ export type components = { * @description Inverse linear interpolation of all pixels of an image */ ImageInverseLerpInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6064,6 +6102,8 @@ export type components = { * @description Linear interpolation of all pixels of an image */ ImageLerpInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6109,6 +6149,8 @@ export type components = { * @description Multiplies two images together using `PIL.ImageChops.multiply()`. 
*/ ImageMultiplyInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6144,6 +6186,8 @@ export type components = { * @description Add blur to NSFW-flagged images */ ImageNSFWBlurInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6252,6 +6296,8 @@ export type components = { * @description Pastes an image into another image. */ ImagePasteInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6337,6 +6383,8 @@ export type components = { * @description Resizes an image to specific dimensions */ ImageResizeInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6446,6 +6494,8 @@ export type components = { * @description Scales an image by a factor */ ImageScaleInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6621,6 +6671,8 @@ export type components = { * @description Add an invisible watermark to an image */ ImageWatermarkInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6676,6 +6728,8 @@ export type components = { * @description Infills transparent areas of an image with a solid color */ InfillColorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6719,6 +6773,8 @@ export type components = { * @description Infills transparent areas of an image using the PatchMatch algorithm */ InfillPatchMatchInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6765,6 +6821,8 @@ export type components = { * @description Infills transparent areas of an image with tiles of the image */ InfillTileInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7065,6 +7123,8 @@ export type components = { * @description Infills transparent areas of an image using the LaMa model */ LaMaInfillInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the 
image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7093,96 +7153,6 @@ export type components = { */ type: "infill_lama"; }; - /** - * Latent Consistency MonoNode - * @description Wrapper node around diffusers LatentConsistencyTxt2ImgPipeline - */ - LatentConsistencyInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Prompt - * @description The prompt to use - */ - prompt?: string; - /** - * Num Inference Steps - * @description The number of inference steps to use, 4-8 recommended - * @default 8 - */ - num_inference_steps?: number; - /** - * Guidance Scale - * @description The guidance scale to use - * @default 8 - */ - guidance_scale?: number; - /** - * Batches - * @description The number of batches to use - * @default 1 - */ - batches?: number; - /** - * Images Per Batch - * @description The number of images per batch to use - * @default 1 - */ - images_per_batch?: number; - /** - * Seeds - * @description List of noise seeds to use - */ - seeds?: number[]; - /** - * Lcm Origin Steps - * @description The lcm origin steps to use - * @default 50 - */ - lcm_origin_steps?: number; - /** - * Width - * @description The width to use - * @default 512 - */ - width?: number; - /** - * Height - * @description The height to use - * @default 512 - */ - height?: number; - /** - * Precision - * @description floating point precision - * @default fp16 - * @enum {string} - */ - precision?: "fp16" | "fp32"; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"]; - /** - * type - * @default latent_consistency_mononode - * @constant - */ - type: "latent_consistency_mononode"; - }; /** * Latents Collection Primitive * @description A collection of latents tensor primitive values @@ -7310,6 +7280,8 @@ export type components = { * @description Generates an image from latents. */ LatentsToImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7357,6 +7329,8 @@ export type components = { * @description Applies leres processing to image */ LeresImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7441,46 +7415,13 @@ export type components = { /** @description Type of commercial use allowed or 'No' if no commercial use is allowed. */ AllowCommercialUse?: components["schemas"]["CommercialUsage"]; }; - /** - * Linear UI Image Output - * @description Handles Linear UI Image Outputting tasks. - */ - LinearUIOutputInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. 
- */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default false - */ - use_cache?: boolean; - /** @description The image to process */ - image?: components["schemas"]["ImageField"]; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** - * type - * @default linear_ui_output - * @constant - */ - type: "linear_ui_output"; - }; /** * Lineart Anime Processor * @description Applies line art anime processing to image */ LineartAnimeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7526,6 +7467,8 @@ export type components = { * @description Applies line art processing to image */ LineartImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7954,6 +7897,8 @@ export type components = { * @description Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`. */ MaskCombineInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7989,6 +7934,8 @@ export type components = { * @description Applies an edge mask to an image */ MaskEdgeInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8042,6 +7989,8 @@ export type components = { * @description Extracts the alpha channel of an image as a mask. */ MaskFromAlphaInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8122,6 +8071,8 @@ export type components = { * @description Applies mediapipe face processing to image */ MediapipeFaceProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8238,6 +8189,8 @@ export type components = { * @description Merge multiple tile images into a single image. 
*/ MergeTilesToImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8404,6 +8357,8 @@ export type components = { * @description Applies Midas depth processing to image */ MidasDepthImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8449,6 +8404,8 @@ export type components = { * @description Applies MLSD processing to image */ MlsdImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8961,6 +8918,8 @@ export type components = { * @description Applies NormalBae processing to image */ NormalbaeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -9591,6 +9550,8 @@ export type components = { * @description Applies PIDI processing to image */ PidiImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -10443,6 +10404,8 @@ export type components = { * @description Saves an image. Unlike an image primitive, this invocation stores a copy of the image. 
*/ SaveImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -10464,8 +10427,6 @@ export type components = { use_cache?: boolean; /** @description The image to process */ image?: components["schemas"]["ImageField"]; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"]; /** * type * @default save_image @@ -10651,6 +10612,8 @@ export type components = { * @description Applies segment anything processing to image */ SegmentAnythingProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12189,6 +12152,8 @@ export type components = { * @description Tile resampler processor */ TileResamplerProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12378,6 +12343,8 @@ export type components = { * @description Applies an unsharp mask filter to an image */ UnsharpMaskInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12846,6 +12813,8 @@ export type components = { * @description Applies Zoe depth processing to image */ ZoeDepthImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** From 8e2b61e19f4911d372d2971c2beb61c834ef5a29 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:41:24 +1100 Subject: [PATCH 30/67] feat(ui): revise graphs to not use `LinearUIOutputInvocation` See this comment for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629 - Remove this now-unnecessary node from all graphs - Update graphs' terminal image-outputting nodes' `is_intermediate` and `board` fields appropriately - Add util function to prepare the `board` field, tidy the utils - Update `socketInvocationComplete` listener to work correctly with this change I've manually tested all graph permutations that were changed (I think this is all...) 
to ensure images go to the gallery as expected: - ad-hoc upscaling - t2i w/ sd1.5 - t2i w/ sd1.5 & hrf - t2i w/ sdxl - t2i w/ sdxl + refiner - i2i w/ sd1.5 - i2i w/ sdxl - i2i w/ sdxl + refiner - canvas t2i w/ sd1.5 - canvas t2i w/ sdxl - canvas t2i w/ sdxl + refiner - canvas i2i w/ sd1.5 - canvas i2i w/ sdxl - canvas i2i w/ sdxl + refiner - canvas inpaint w/ sd1.5 - canvas inpaint w/ sdxl - canvas inpaint w/ sdxl + refiner - canvas outpaint w/ sd1.5 - canvas outpaint w/ sdxl - canvas outpaint w/ sdxl + refiner --- .../socketio/socketInvocationComplete.ts | 9 +-- .../listeners/upscaleRequested.ts | 6 +- .../nodes/util/graph/addHrfToGraph.ts | 4 +- .../nodes/util/graph/addLinearUIOutputNode.ts | 78 ------------------- .../nodes/util/graph/addNSFWCheckerToGraph.ts | 4 +- .../nodes/util/graph/addSDXLRefinerToGraph.ts | 2 +- .../nodes/util/graph/addWatermarkerToGraph.ts | 9 +-- .../util/graph/buildAdHocUpscaleGraph.ts | 40 +++------- .../graph/buildCanvasImageToImageGraph.ts | 13 ++-- .../util/graph/buildCanvasInpaintGraph.ts | 7 +- .../util/graph/buildCanvasOutpaintGraph.ts | 7 +- .../graph/buildCanvasSDXLImageToImageGraph.ts | 8 +- .../util/graph/buildCanvasSDXLInpaintGraph.ts | 8 +- .../graph/buildCanvasSDXLOutpaintGraph.ts | 8 +- .../graph/buildCanvasSDXLTextToImageGraph.ts | 11 ++- .../util/graph/buildCanvasTextToImageGraph.ts | 10 +-- .../graph/buildLinearImageToImageGraph.ts | 7 +- .../graph/buildLinearSDXLImageToImageGraph.ts | 8 +- .../graph/buildLinearSDXLTextToImageGraph.ts | 8 +- .../util/graph/buildLinearTextToImageGraph.ts | 7 +- .../features/nodes/util/graph/constants.ts | 1 - .../nodes/util/graph/getSDXLStylePrompt.ts | 11 --- .../nodes/util/graph/graphBuilderUtils.ts | 38 +++++++++ .../frontend/web/src/services/api/types.ts | 1 - 24 files changed, 108 insertions(+), 197 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index d49f35cd2ab..75fa9e10949 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -4,7 +4,7 @@ import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { isImageOutput } from 'features/nodes/types/common'; -import { LINEAR_UI_OUTPUT, nodeIDDenyList } from 'features/nodes/util/graph/constants'; +import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; import { imagesAdapter } from 'services/api/util'; @@ -24,10 +24,9 @@ export const addInvocationCompleteEventListener = () => { const { data } = action.payload; log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`); - const { result, node, queue_batch_id, source_node_id } = data; - + 
const { result, node, queue_batch_id } = data; // This complete event has an associated image output - if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type) && !nodeIDDenyList.includes(source_node_id)) { + if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) { const { image_name } = result.image; const { canvas, gallery } = getState(); @@ -42,7 +41,7 @@ export const addInvocationCompleteEventListener = () => { imageDTORequest.unsubscribe(); // Add canvas images to the staging area - if (canvas.batchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) { + if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) { dispatch(addImageToStagingArea(imageDTO)); } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts index 46f55ef21ff..ab989301796 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts @@ -39,16 +39,12 @@ export const addUpscaleRequestedListener = () => { return; } - const { esrganModelName } = state.postprocessing; - const { autoAddBoardId } = state.gallery; - const enqueueBatchArg: BatchConfig = { prepend: true, batch: { graph: buildAdHocUpscaleGraph({ image_name, - esrganModelName, - autoAddBoardId, + state, }), runs: 1, }, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts index 8a4448833cf..5632cfd1122 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import type { DenoiseLatentsInvocation, @@ -322,7 +323,8 @@ export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void = type: 'l2i', id: LATENTS_TO_IMAGE_HRF_HR, fp32: originalLatentsToImageNode?.fp32, - is_intermediate: true, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.edges.push( { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts deleted file mode 100644 index 5c78ad804ed..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts +++ /dev/null @@ -1,78 +0,0 @@ -import type { RootState } from 'app/store/store'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import type { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; - -import { - CANVAS_OUTPUT, - LATENTS_TO_IMAGE, - LATENTS_TO_IMAGE_HRF_HR, - LINEAR_UI_OUTPUT, - NSFW_CHECKER, - WATERMARKER, -} from './constants'; - -/** - * Set the `use_cache` field on the linear/canvas graph's final image output node to False. 
- */ -export const addLinearUIOutputNode = (state: RootState, graph: NonNullableGraph): void => { - const activeTabName = activeTabNameSelector(state); - const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; - const { autoAddBoardId } = state.gallery; - - const linearUIOutputNode: LinearUIOutputInvocation = { - id: LINEAR_UI_OUTPUT, - type: 'linear_ui_output', - is_intermediate, - use_cache: false, - board: autoAddBoardId === 'none' ? undefined : { board_id: autoAddBoardId }, - }; - - graph.nodes[LINEAR_UI_OUTPUT] = linearUIOutputNode; - - const destination = { - node_id: LINEAR_UI_OUTPUT, - field: 'image', - }; - - if (WATERMARKER in graph.nodes) { - graph.edges.push({ - source: { - node_id: WATERMARKER, - field: 'image', - }, - destination, - }); - } else if (NSFW_CHECKER in graph.nodes) { - graph.edges.push({ - source: { - node_id: NSFW_CHECKER, - field: 'image', - }, - destination, - }); - } else if (CANVAS_OUTPUT in graph.nodes) { - graph.edges.push({ - source: { - node_id: CANVAS_OUTPUT, - field: 'image', - }, - destination, - }); - } else if (LATENTS_TO_IMAGE_HRF_HR in graph.nodes) { - graph.edges.push({ - source: { - node_id: LATENTS_TO_IMAGE_HRF_HR, - field: 'image', - }, - destination, - }); - } else if (LATENTS_TO_IMAGE in graph.nodes) { - graph.edges.push({ - source: { - node_id: LATENTS_TO_IMAGE, - field: 'image', - }, - destination, - }); - } -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts index 4a8e77abfa0..35fc3246890 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts @@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store'; import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants'; +import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; export const addNSFWCheckerToGraph = ( state: RootState, @@ -21,7 +22,8 @@ export const addNSFWCheckerToGraph = ( const nsfwCheckerNode: ImageNSFWBlurInvocation = { id: NSFW_CHECKER, type: 'img_nsfw', - is_intermediate: true, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts index 708353e4d6a..fc4d998969d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts @@ -24,7 +24,7 @@ import { SDXL_REFINER_POSITIVE_CONDITIONING, SDXL_REFINER_SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getSDXLStylePrompts } from './graphBuilderUtils'; import { upsertMetadata } from './metadata'; export const addSDXLRefinerToGraph = ( diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts index 99c5c07be47..61beb11df49 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts @@ 
-1,5 +1,4 @@ import type { RootState } from 'app/store/store'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import type { ImageNSFWBlurInvocation, ImageWatermarkInvocation, @@ -8,16 +7,13 @@ import type { } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants'; +import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; export const addWatermarkerToGraph = ( state: RootState, graph: NonNullableGraph, nodeIdToAddTo = LATENTS_TO_IMAGE ): void => { - const activeTabName = activeTabNameSelector(state); - - const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; - const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined; const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined; @@ -30,7 +26,8 @@ export const addWatermarkerToGraph = ( const watermarkerNode: ImageWatermarkInvocation = { id: WATERMARKER, type: 'img_watermark', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.nodes[WATERMARKER] = watermarkerNode; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts index fa20206d91f..52c09b1db06 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts @@ -1,51 +1,33 @@ -import type { BoardId } from 'features/gallery/store/types'; -import type { ParamESRGANModelName } from 'features/parameters/store/postprocessingSlice'; -import type { ESRGANInvocation, Graph, LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; +import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; +import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types'; -import { ESRGAN, LINEAR_UI_OUTPUT } from './constants'; +import { ESRGAN } from './constants'; import { addCoreMetadataNode, upsertMetadata } from './metadata'; type Arg = { image_name: string; - esrganModelName: ParamESRGANModelName; - autoAddBoardId: BoardId; + state: RootState; }; -export const buildAdHocUpscaleGraph = ({ image_name, esrganModelName, autoAddBoardId }: Arg): Graph => { +export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => { + const { esrganModelName } = state.postprocessing; + const realesrganNode: ESRGANInvocation = { id: ESRGAN, type: 'esrgan', image: { image_name }, model_name: esrganModelName, - is_intermediate: true, - }; - - const linearUIOutputNode: LinearUIOutputInvocation = { - id: LINEAR_UI_OUTPUT, - type: 'linear_ui_output', - use_cache: false, - is_intermediate: false, - board: autoAddBoardId === 'none' ? 
undefined : { board_id: autoAddBoardId }, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; const graph: NonNullableGraph = { id: `adhoc-esrgan-graph`, nodes: { [ESRGAN]: realesrganNode, - [LINEAR_UI_OUTPUT]: linearUIOutputNode, }, - edges: [ - { - source: { - node_id: ESRGAN, - field: 'image', - }, - destination: { - node_id: LINEAR_UI_OUTPUT, - field: 'image', - }, - }, - ], + edges: [], }; addCoreMetadataNode(graph, {}, ESRGAN); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts index 3002e05441b..bc6a83f4fa6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -132,7 +132,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima [CANVAS_OUTPUT]: { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -242,7 +243,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -284,7 +286,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -355,7 +358,5 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts index bb52a44a8e4..d983b9cf4f5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { CreateDenoiseMaskInvocation, ImageBlurInvocation, @@ -12,7 +13,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from 
'./addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -191,7 +191,8 @@ export const buildCanvasInpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), reference: canvasInitImage, use_cache: false, }, @@ -663,7 +664,5 @@ export const buildCanvasInpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts index b82b55cfee7..1d028943818 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageDTO, ImageToLatentsInvocation, @@ -11,7 +12,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -200,7 +200,8 @@ export const buildCanvasOutpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -769,7 +770,5 @@ export const buildCanvasOutpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts index 1b586371a02..58269afce3f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts @@ -4,7 +4,6 @@ import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'servi import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -26,7 +25,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; /** @@ -246,7 +245,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + 
is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -368,7 +368,5 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts index 00fea9a37e6..5902dee2fc4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts @@ -12,7 +12,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -44,7 +43,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; /** * Builds the Canvas tab's Inpaint graph. @@ -190,7 +189,8 @@ export const buildCanvasSDXLInpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), reference: canvasInitImage, use_cache: false, }, @@ -687,7 +687,5 @@ export const buildCanvasSDXLInpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts index f85760d8f2e..7a78750e8d2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts @@ -11,7 +11,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -46,7 +45,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; /** * Builds the Canvas tab's Outpaint graph. 
@@ -199,7 +198,8 @@ export const buildCanvasSDXLOutpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -786,7 +786,5 @@ export const buildCanvasSDXLOutpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts index 91d9da4cb5d..22da39c67da 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts @@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -24,7 +23,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; /** @@ -222,7 +221,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -254,7 +254,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -330,7 +331,5 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts index 967dd3ff4a5..93f0470c7ad 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -211,7 +211,8 @@ export const buildCanvasTextToImageGraph = 
(state: RootState): NonNullableGraph graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -243,7 +244,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -310,7 +312,5 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts index c76776d94d3..d1f1546b23b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -117,7 +117,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }, [DENOISE_LATENTS]: { type: 'denoise_latents', @@ -358,7 +359,5 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts index 9ae602bcacb..de4ad7ceceb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts @@ -4,7 +4,6 @@ import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -25,7 +24,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from 
'./metadata'; /** @@ -120,7 +119,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }, [SDXL_DENOISE_LATENTS]: { type: 'denoise_latents', @@ -380,7 +380,5 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts index 222dc1a3595..58b97b07c75 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts @@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -23,7 +22,7 @@ import { SDXL_TEXT_TO_IMAGE_GRAPH, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => { @@ -120,7 +119,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -281,7 +281,5 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts index 0a45d91debc..b2b84cfdad7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts @@ -1,11 +1,11 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addHrfToGraph } from './addHrfToGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -119,7 +119,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: 
getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -267,7 +268,5 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts index 363d3191210..767bf25df0a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts @@ -9,7 +9,6 @@ export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr'; export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf'; export const RESIZE_HRF = 'resize_hrf'; export const ESRGAN_HRF = 'esrgan_hrf'; -export const LINEAR_UI_OUTPUT = 'linear_ui_output'; export const NSFW_CHECKER = 'nsfw_checker'; export const WATERMARKER = 'invisible_watermark'; export const NOISE = 'noise'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts b/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts deleted file mode 100644 index e1cd8518fdd..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts +++ /dev/null @@ -1,11 +0,0 @@ -import type { RootState } from 'app/store/store'; - -export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => { - const { positivePrompt, negativePrompt } = state.generation; - const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl; - - return { - positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt, - negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt, - }; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts new file mode 100644 index 00000000000..cb6fc9acf1e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts @@ -0,0 +1,38 @@ +import type { RootState } from 'app/store/store'; +import type { BoardField } from 'features/nodes/types/common'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; + +/** + * Gets the board field, based on the autoAddBoardId setting. + */ +export const getBoardField = (state: RootState): BoardField | undefined => { + const { autoAddBoardId } = state.gallery; + if (autoAddBoardId === 'none') { + return undefined; + } + return { board_id: autoAddBoardId }; +}; + +/** + * Gets the SDXL style prompts, based on the concat setting. + */ +export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => { + const { positivePrompt, negativePrompt } = state.generation; + const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl; + + return { + positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt, + negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt, + }; +}; + +/** + * Gets the is_intermediate field, based on the active tab and shouldAutoSave setting. 
+ */ +export const getIsIntermediate = (state: RootState) => { + const activeTabName = activeTabNameSelector(state); + if (activeTabName === 'unifiedCanvas') { + return !state.canvas.shouldAutoSave; + } + return false; +}; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 1382fbe275a..55ff808b404 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -132,7 +132,6 @@ export type DivideInvocation = s['DivideInvocation']; export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation']; export type ImageWatermarkInvocation = s['ImageWatermarkInvocation']; export type SeamlessModeInvocation = s['SeamlessModeInvocation']; -export type LinearUIOutputInvocation = s['LinearUIOutputInvocation']; export type MetadataInvocation = s['MetadataInvocation']; export type CoreMetadataInvocation = s['CoreMetadataInvocation']; export type MetadataItemInvocation = s['MetadataItemInvocation']; From a5db204629181af352b09ae5af86618ece58b408 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:01:39 +1100 Subject: [PATCH 31/67] tidy(nodes): remove unnecessary, shadowing class attr declarations --- invokeai/app/services/invocation_services.py | 27 -------------------- 1 file changed, 27 deletions(-) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 11a4de99d6e..51bfd5d77a1 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -36,33 +36,6 @@ class InvocationServices: """Services that can be used by invocations""" - # TODO: Just forward-declared everything due to circular dependencies. Fix structure. 
- board_images: "BoardImagesServiceABC" - board_image_record_storage: "BoardImageRecordStorageBase" - boards: "BoardServiceABC" - board_records: "BoardRecordStorageBase" - configuration: "InvokeAIAppConfig" - events: "EventServiceBase" - graph_execution_manager: "ItemStorageABC[GraphExecutionState]" - images: "ImageServiceABC" - image_records: "ImageRecordStorageBase" - image_files: "ImageFileStorageBase" - latents: "LatentsStorageBase" - logger: "Logger" - model_manager: "ModelManagerServiceBase" - model_records: "ModelRecordServiceBase" - download_queue: "DownloadQueueServiceBase" - model_install: "ModelInstallServiceBase" - processor: "InvocationProcessorABC" - performance_statistics: "InvocationStatsServiceBase" - queue: "InvocationQueueABC" - session_queue: "SessionQueueBase" - session_processor: "SessionProcessorBase" - invocation_cache: "InvocationCacheBase" - names: "NameServiceBase" - urls: "UrlServiceBase" - workflow_records: "WorkflowRecordsStorageBase" - def __init__( self, board_images: "BoardImagesServiceABC", From db6bc7305a4830799a62b21cd3ff46f30bc4f35e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:10:25 +1100 Subject: [PATCH 32/67] fix(nodes): rearrange fields.py to avoid needing forward refs --- invokeai/app/invocations/fields.py | 92 +++++++++++++++--------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index c42d2f83120..40d403c03d9 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -182,6 +182,51 @@ class FieldDescriptions: freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." +class ImageField(BaseModel): + """An image primitive field""" + + image_name: str = Field(description="The name of the image") + + +class BoardField(BaseModel): + """A board primitive field""" + + board_id: str = Field(description="The id of the board") + + +class DenoiseMaskField(BaseModel): + """An inpaint mask field""" + + mask_name: str = Field(description="The name of the mask image") + masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") + + +class LatentsField(BaseModel): + """A latents tensor primitive field""" + + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") + + +class ColorField(BaseModel): + """A color primitive field""" + + r: int = Field(ge=0, le=255, description="The red component") + g: int = Field(ge=0, le=255, description="The green component") + b: int = Field(ge=0, le=255, description="The blue component") + a: int = Field(ge=0, le=255, description="The alpha component") + + def tuple(self) -> Tuple[int, int, int, int]: + return (self.r, self.g, self.b, self.a) + + +class ConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") + # endregion + + class MetadataField(RootModel): """ Pydantic model for metadata with custom root of type dict[str, Any]. @@ -285,7 +330,7 @@ class WithBoard(BaseModel): Inherit from this class if your node needs a board input field. 
""" - board: Optional["BoardField"] = Field( + board: Optional[BoardField] = Field( default=None, description=FieldDescriptions.board, json_schema_extra=InputFieldJSONSchemaExtra( @@ -518,48 +563,3 @@ def OutputField( field_kind=FieldKind.Output, ).model_dump(exclude_none=True), ) - - -class ImageField(BaseModel): - """An image primitive field""" - - image_name: str = Field(description="The name of the image") - - -class BoardField(BaseModel): - """A board primitive field""" - - board_id: str = Field(description="The id of the board") - - -class DenoiseMaskField(BaseModel): - """An inpaint mask field""" - - mask_name: str = Field(description="The name of the mask image") - masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") - - -class LatentsField(BaseModel): - """A latents tensor primitive field""" - - latents_name: str = Field(description="The name of the latents") - seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") - - -class ColorField(BaseModel): - """A color primitive field""" - - r: int = Field(ge=0, le=255, description="The red component") - g: int = Field(ge=0, le=255, description="The green component") - b: int = Field(ge=0, le=255, description="The blue component") - a: int = Field(ge=0, le=255, description="The alpha component") - - def tuple(self) -> Tuple[int, int, int, int]: - return (self.r, self.g, self.b, self.a) - - -class ConditioningField(BaseModel): - """A conditioning tensor primitive value""" - - conditioning_name: str = Field(description="The name of conditioning tensor") - # endregion From 2932652787221019eb1b18e5e6df395b2b5491bb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:11:22 +1100 Subject: [PATCH 33/67] tidy(nodes): delete onnx.py It doesn't work and keeping it updated to prevent the app from starting was getting tedious. Deleted. 
--- invokeai/app/invocations/onnx.py | 510 ------------------------------- 1 file changed, 510 deletions(-) delete mode 100644 invokeai/app/invocations/onnx.py diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py deleted file mode 100644 index e7b4d3d9fc5..00000000000 --- a/invokeai/app/invocations/onnx.py +++ /dev/null @@ -1,510 +0,0 @@ -# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) - -import inspect - -# from contextlib import ExitStack -from typing import List, Literal, Union - -import numpy as np -import torch -from diffusers.image_processor import VaeImageProcessor -from pydantic import BaseModel, ConfigDict, Field, field_validator -from tqdm import tqdm - -from invokeai.app.invocations.fields import ( - FieldDescriptions, - Input, - InputField, - OutputField, - UIComponent, - UIType, - WithMetadata, -) -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend import BaseModelType, ModelType, SubModelType - -from ...backend.model_management import ONNXModelPatcher -from ...backend.stable_diffusion import PipelineIntermediateState -from ...backend.util import choose_torch_device -from ..util.ti_utils import extract_ti_triggers_from_prompt -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - InvocationContext, - invocation, - invocation_output, -) -from .controlnet_image_processors import ControlField -from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, get_scheduler -from .model import ClipField, ModelInfo, UNetField, VaeField - -ORT_TO_NP_TYPE = { - "tensor(bool)": np.bool_, - "tensor(int8)": np.int8, - "tensor(uint8)": np.uint8, - "tensor(int16)": np.int16, - "tensor(uint16)": np.uint16, - "tensor(int32)": np.int32, - "tensor(uint32)": np.uint32, - "tensor(int64)": np.int64, - "tensor(uint64)": np.uint64, - "tensor(float16)": np.float16, - "tensor(float)": np.float32, - "tensor(double)": np.float64, -} - -PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())] - - -@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0") -class ONNXPromptInvocation(BaseInvocation): - prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) - clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - - def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - ) - with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.clip.loras - ] - - ti_list = [] - for trigger in extract_ti_triggers_from_prompt(self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model, - ) - ) - except Exception: - # print(e) - # import traceback - # 
print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') - if loras or ti_list: - text_encoder.release_session() - with ( - ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), - ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager), - ): - text_encoder.create_session() - - # copy from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153 - text_inputs = tokenizer( - self.prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - """ - untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - """ - - prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (prompt_embeds, None)) - - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) - - -# Text to image -@invocation( - "t2l_onnx", - title="ONNX Text to Latents", - tags=["latents", "inference", "txt2img", "onnx"], - category="latents", - version="1.0.0", -) -class ONNXTextToLatentsInvocation(BaseInvocation): - """Generates latents from conditionings.""" - - positive_conditioning: ConditioningField = InputField( - description=FieldDescriptions.positive_cond, - input=Input.Connection, - ) - negative_conditioning: ConditioningField = InputField( - description=FieldDescriptions.negative_cond, - input=Input.Connection, - ) - noise: LatentsField = InputField( - description=FieldDescriptions.noise, - input=Input.Connection, - ) - steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) - cfg_scale: Union[float, List[float]] = InputField( - default=7.5, - ge=1, - description=FieldDescriptions.cfg_scale, - ) - scheduler: SAMPLER_NAME_VALUES = InputField( - default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler - ) - precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision) - unet: UNetField = InputField( - description=FieldDescriptions.unet, - input=Input.Connection, - ) - control: Union[ControlField, list[ControlField]] = InputField( - default=None, - description=FieldDescriptions.control, - ) - # seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", ) - # seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'") - - @field_validator("cfg_scale") - def ge_one(cls, v): - """validate that all cfg_scale values are >= 1""" - if isinstance(v, list): - for i in v: - if i < 1: - raise ValueError("cfg_scale must be greater than 1") - else: - if v < 1: - raise ValueError("cfg_scale must be greater than 1") - return v - - # based on - # 
https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context: InvocationContext) -> LatentsOutput: - c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - if isinstance(c, torch.Tensor): - c = c.cpu().numpy() - if isinstance(uc, torch.Tensor): - uc = uc.cpu().numpy() - device = torch.device(choose_torch_device()) - prompt_embeds = np.concatenate([uc, c]) - - latents = context.services.latents.get(self.noise.latents_name) - if isinstance(latents, torch.Tensor): - latents = latents.cpu().numpy() - - # TODO: better execution device handling - latents = latents.astype(ORT_TO_NP_TYPE[self.precision]) - - # get the initial random noise unless the user supplied it - do_classifier_free_guidance = True - # latents_dtype = prompt_embeds.dtype - # latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) - # if latents.shape != latents_shape: - # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - scheduler = get_scheduler( - context=context, - scheduler_info=self.unet.scheduler, - scheduler_name=self.scheduler, - seed=0, # TODO: refactor this node - ) - - def torch2numpy(latent: torch.Tensor): - return latent.cpu().numpy() - - def numpy2torch(latent, device): - return torch.from_numpy(latent).to(device) - - def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - ) - - scheduler.set_timesteps(self.steps) - latents = latents * np.float64(scheduler.init_noise_sigma) - - extra_step_kwargs = {} - if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): - extra_step_kwargs.update( - eta=0.0, - ) - - unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump()) - - with unet_info as unet: # , ExitStack() as stack: - # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.unet.loras - ] - - if loras: - unet.release_session() - with ONNXModelPatcher.apply_lora_unet(unet, loras): - # TODO: - _, _, h, w = latents.shape - unet.create_session(h, w) - - timestep_dtype = next( - (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)" - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - for i in tqdm(range(len(scheduler.timesteps))): - t = scheduler.timesteps[i] - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = 
unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = scheduler.step( - numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs - ) - latents = torch2numpy(scheduler_output.prev_sample) - - state = PipelineIntermediateState( - run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample - ) - dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state) - - # call the callback, if provided - # if callback is not None and i % callback_steps == 0: - # callback(i, t, latents) - - torch.cuda.empty_cache() - - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, latents) - # return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) - - -# Latent to image -@invocation( - "l2i_onnx", - title="ONNX Latents to Image", - tags=["latents", "image", "vae", "onnx"], - category="image", - version="1.2.0", -) -class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): - """Generates an image from latents.""" - - latents: LatentsField = InputField( - description=FieldDescriptions.denoised_latents, - input=Input.Connection, - ) - vae: VaeField = InputField( - description=FieldDescriptions.vae, - input=Input.Connection, - ) - # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) - - if self.vae.vae.submodel != SubModelType.VaeDecoder: - raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}") - - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - ) - - # clear memory as vae decode can request a lot - torch.cuda.empty_cache() - - with vae_info as vae: - vae.create_session() - - # copied from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427 - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]) - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - image = VaeImageProcessor.numpy_to_pil(image)[0] - - torch.cuda.empty_cache() - - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) - - -@invocation_output("model_loader_output_onnx") -class ONNXModelLoaderOutput(BaseInvocationOutput): - """Model loader output""" - - unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") - clip: 
ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") - vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder") - vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder") - - -class OnnxModelField(BaseModel): - """Onnx model field""" - - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) - - -@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0") -class OnnxModelLoaderInvocation(BaseInvocation): - """Loads a main model, outputting its submodels.""" - - model: OnnxModelField = InputField( - description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel - ) - - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.ONNX - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! 
Check if model corrupted" - ) - """ - - return ONNXModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, - ), - loras=[], - skipped_layers=0, - ), - vae_decoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeDecoder, - ), - ), - vae_encoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeEncoder, - ), - ), - ) From de0b72528cbb3f7ffcdafcad5c3232e81e748fd2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:41:23 +1100 Subject: [PATCH 34/67] feat(nodes): replace latents service with tensors and conditioning services - New generic class `PickleStorageBase` implements the same API as `LatentsStorageBase`; it is used for storing non-serializable data via pickling - Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as the old `DiskLatentsStorage` - Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices` - Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices` - Remove `latents` service and all `LatentsStorage` classes - Update `InvocationContext` and all usage of the old `latents` service to use the new services/context wrapper methods --- invokeai/app/api/dependencies.py | 18 +++-- invokeai/app/invocations/latent.py | 36 +++++----- invokeai/app/invocations/noise.py | 2 +- invokeai/app/invocations/primitives.py | 6 +- .../invocation_cache_memory.py | 3 +- invokeai/app/services/invocation_services.py | 12 +++- .../app/services/latents_storage/__init__.py | 0 .../latents_storage/latents_storage_disk.py | 58 ---------------- .../latents_storage_forward_cache.py | 68 ------------------- .../pickle_storage_base.py} | 18 ++--- .../pickle_storage_forward_cache.py | 58 ++++++++++++++++ .../pickle_storage/pickle_storage_torch.py | 62 +++++++++++++++++ .../app/services/shared/invocation_context.py | 49 ++++++------- 13 files changed, 197 insertions(+), 193 deletions(-) delete mode 100644 invokeai/app/services/latents_storage/__init__.py delete mode 100644 invokeai/app/services/latents_storage/latents_storage_disk.py delete mode 100644 invokeai/app/services/latents_storage/latents_storage_forward_cache.py rename invokeai/app/services/{latents_storage/latents_storage_base.py => pickle_storage/pickle_storage_base.py} (68%) create mode 100644 invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py create mode 100644 invokeai/app/services/pickle_storage/pickle_storage_torch.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index c8309e1729e..6bb0915cb6e 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,9 +2,14 @@ from logging import Logger +import torch + from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory +from
invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache +from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -23,8 +28,6 @@ from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker -from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage -from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage from ..services.model_install import ModelInstallService from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_records import ModelRecordServiceSQL @@ -68,6 +71,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger logger.debug(f"Internet connectivity is {config.internet_available}") output_folder = config.output_path + if output_folder is None: + raise ValueError("Output folder is not set") + image_files = DiskImageFileStorage(f"{output_folder}/images") db = init_db(config=config, logger=logger, image_files=image_files) @@ -84,7 +90,10 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) + tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor")) + conditioning = PickleStorageForwardCache( + PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning") + ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) download_queue_service = DownloadQueueService(event_bus=events) @@ -117,7 +126,6 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records=image_records, images=images, invocation_cache=invocation_cache, - latents=latents, logger=logger, model_manager=model_manager, model_records=model_record_service, @@ -131,6 +139,8 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger session_queue=session_queue, urls=urls, workflow_records=workflow_records, + tensors=tensors, + conditioning=conditioning, ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5449ec9af7a..94440d3e2aa 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -163,11 +163,11 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = context.latents.save(tensor=masked_latents) + masked_latents_name = context.tensors.save(tensor=masked_latents) else: masked_latents_name = None - mask_name = context.latents.save(tensor=mask) + mask_name = context.tensors.save(tensor=mask) return DenoiseMaskOutput.build( 
mask_name=mask_name, @@ -621,10 +621,10 @@ def prep_inpaint_mask(self, context: InvocationContext, latents): if self.denoise_mask is None: return None, None - mask = context.latents.get(self.denoise_mask.mask_name) + mask = context.tensors.get(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.latents.get(self.denoise_mask.masked_latents_name) + masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name) else: masked_latents = None @@ -636,11 +636,11 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: seed = None noise = None if self.noise is not None: - noise = context.latents.get(self.noise.latents_name) + noise = context.tensors.get(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -752,7 +752,7 @@ def _lora_loader(): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=result_latents) + name = context.tensors.save(tensor=result_latents) return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed) @@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -888,7 +888,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=resized_latents) + name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -930,7 +930,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=resized_latents) + name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -1011,7 +1011,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) latents = latents.to("cpu") - name = context.latents.save(tensor=latents) + name = context.tensors.save(tensor=latents) return LatentsOutput.build(latents_name=name, latents=latents, seed=None) @singledispatchmethod @@ -1048,8 +1048,8 @@ class 
BlendLatentsInvocation(BaseInvocation): alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.latents.get(self.latents_a.latents_name) - latents_b = context.latents.get(self.latents_b.latents_name) + latents_a = context.tensors.get(self.latents_a.latents_name) + latents_b = context.tensors.get(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1103,7 +1103,7 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=blended_latents) + name = context.tensors.save(tensor=blended_latents) return LatentsOutput.build(latents_name=name, latents=blended_latents) @@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR @@ -1158,7 +1158,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: cropped_latents = latents[..., y1:y2, x1:x2] - name = context.latents.save(tensor=cropped_latents) + name = context.tensors.save(tensor=cropped_latents) return LatentsOutput.build(latents_name=name, latents=cropped_latents) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 78f13cc52d1..74b3d6e4cb1 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -121,5 +121,5 @@ def invoke(self, context: InvocationContext) -> NoiseOutput: seed=self.seed, use_cpu=self.use_cpu, ) - name = context.latents.save(tensor=noise) + name = context.tensors.save(tensor=noise) return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index a77939943ae..082d5432ccf 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -313,9 +313,7 @@ def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "De class LatentsOutput(BaseInvocationOutput): """Base class for nodes that output a single latents tensor""" - latents: LatentsField = OutputField( - description=FieldDescriptions.latents, - ) + latents: LatentsField = OutputField(description=FieldDescriptions.latents) width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) @@ -346,7 +344,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index 4a503b3c6b1..c700f81186f 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -37,7 +37,8 @@ def start(self, invoker: Invoker) -> None: if self._max_cache_size == 0: return self._invoker.services.images.on_deleted(self._delete_by_match) 
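+ # Invalidate cached outputs when their backing tensors or conditioning data are deleted.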
- self._invoker.services.latents.on_deleted(self._delete_by_match) + self._invoker.services.tensors.on_deleted(self._delete_by_match) + self._invoker.services.conditioning.on_deleted(self._delete_by_match) def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: with self._lock: diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 51bfd5d77a1..81885781acb 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -6,6 +6,10 @@ if TYPE_CHECKING: from logging import Logger + import torch + + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData + from .board_image_records.board_image_records_base import BoardImageRecordStorageBase from .board_images.board_images_base import BoardImagesServiceABC from .board_records.board_records_base import BoardRecordStorageBase @@ -21,11 +25,11 @@ from .invocation_queue.invocation_queue_base import InvocationQueueABC from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .item_storage.item_storage_base import ItemStorageABC - from .latents_storage.latents_storage_base import LatentsStorageBase from .model_install import ModelInstallServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase + from .pickle_storage.pickle_storage_base import PickleStorageBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase from .shared.graph import GraphExecutionState @@ -48,7 +52,6 @@ def __init__( images: "ImageServiceABC", image_files: "ImageFileStorageBase", image_records: "ImageRecordStorageBase", - latents: "LatentsStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", model_records: "ModelRecordServiceBase", @@ -63,6 +66,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", + tensors: "PickleStorageBase[torch.Tensor]", + conditioning: "PickleStorageBase[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records @@ -74,7 +79,6 @@ def __init__( self.images = images self.image_files = image_files self.image_records = image_records - self.latents = latents self.logger = logger self.model_manager = model_manager self.model_records = model_records @@ -89,3 +93,5 @@ def __init__( self.names = names self.urls = urls self.workflow_records = workflow_records + self.tensors = tensors + self.conditioning = conditioning diff --git a/invokeai/app/services/latents_storage/__init__.py b/invokeai/app/services/latents_storage/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/invokeai/app/services/latents_storage/latents_storage_disk.py b/invokeai/app/services/latents_storage/latents_storage_disk.py deleted file mode 100644 index 9192b9147f7..00000000000 --- a/invokeai/app/services/latents_storage/latents_storage_disk.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from pathlib import Path -from typing import Union - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class DiskLatentsStorage(LatentsStorageBase): - """Stores latents in a folder on disk without caching""" - - __output_folder: Path - - 
def __init__(self, output_folder: Union[str, Path]): - self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) - self.__output_folder.mkdir(parents=True, exist_ok=True) - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_latents() - - def get(self, name: str) -> torch.Tensor: - latent_path = self.get_path(name) - return torch.load(latent_path) - - def save(self, name: str, data: torch.Tensor) -> None: - self.__output_folder.mkdir(parents=True, exist_ok=True) - latent_path = self.get_path(name) - torch.save(data, latent_path) - - def delete(self, name: str) -> None: - latent_path = self.get_path(name) - latent_path.unlink() - - def get_path(self, name: str) -> Path: - return self.__output_folder / name - - def _delete_all_latents(self) -> None: - """ - Deletes all latents from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - deleted_latents_count = 0 - freed_space = 0 - for latents_file in Path(self.__output_folder).glob("*"): - if latents_file.is_file(): - freed_space += latents_file.stat().st_size - deleted_latents_count += 1 - latents_file.unlink() - if deleted_latents_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py deleted file mode 100644 index 6232b76a27d..00000000000 --- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from queue import Queue -from typing import Dict, Optional - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class ForwardCacheLatentsStorage(LatentsStorageBase): - """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" - - __cache: Dict[str, torch.Tensor] - __cache_ids: Queue - __max_cache_size: int - __underlying_storage: LatentsStorageBase - - def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): - super().__init__() - self.__underlying_storage = underlying_storage - self.__cache = {} - self.__cache_ids = Queue() - self.__max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self.__underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self.__underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, name: str) -> torch.Tensor: - cache_item = self.__get_cache(name) - if cache_item is not None: - return cache_item - - latent = self.__underlying_storage.get(name) - self.__set_cache(name, latent) - return latent - - def save(self, name: str, data: torch.Tensor) -> None: - self.__underlying_storage.save(name, data) - self.__set_cache(name, data) - self._on_changed(data) - - def delete(self, name: str) -> None: - self.__underlying_storage.delete(name) - if name in self.__cache: - del self.__cache[name] - self._on_deleted(name) - - def __get_cache(self, name: str) -> Optional[torch.Tensor]: - return None if name not in self.__cache else self.__cache[name] 
- - def __set_cache(self, name: str, data: torch.Tensor): - if name not in self.__cache: - self.__cache[name] = data - self.__cache_ids.put(name) - if self.__cache_ids.qsize() > self.__max_cache_size: - self.__cache.pop(self.__cache_ids.get()) diff --git a/invokeai/app/services/latents_storage/latents_storage_base.py b/invokeai/app/services/pickle_storage/pickle_storage_base.py similarity index 68% rename from invokeai/app/services/latents_storage/latents_storage_base.py rename to invokeai/app/services/pickle_storage/pickle_storage_base.py index 9fa42b0ae61..558b97c0f1b 100644 --- a/invokeai/app/services/latents_storage/latents_storage_base.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_base.py @@ -1,15 +1,15 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Generic, TypeVar -import torch +T = TypeVar("T") -class LatentsStorageBase(ABC): - """Responsible for storing and retrieving latents.""" +class PickleStorageBase(ABC, Generic[T]): + """Responsible for storing and retrieving non-serializable data using a pickler.""" - _on_changed_callbacks: list[Callable[[torch.Tensor], None]] + _on_changed_callbacks: list[Callable[[T], None]] _on_deleted_callbacks: list[Callable[[str], None]] def __init__(self) -> None: @@ -17,18 +17,18 @@ def __init__(self) -> None: self._on_deleted_callbacks = [] @abstractmethod - def get(self, name: str) -> torch.Tensor: + def get(self, name: str) -> T: pass @abstractmethod - def save(self, name: str, data: torch.Tensor) -> None: + def save(self, name: str, data: T) -> None: pass @abstractmethod def delete(self, name: str) -> None: pass - def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None: + def on_changed(self, on_changed: Callable[[T], None]) -> None: """Register a callback for when an item is changed""" self._on_changed_callbacks.append(on_changed) @@ -36,7 +36,7 @@ def on_deleted(self, on_deleted: Callable[[str], None]) -> None: """Register a callback for when an item is deleted""" self._on_deleted_callbacks.append(on_deleted) - def _on_changed(self, item: torch.Tensor) -> None: + def _on_changed(self, item: T) -> None: for callback in self._on_changed_callbacks: callback(item) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py new file mode 100644 index 00000000000..3002d9e045d --- /dev/null +++ b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py @@ -0,0 +1,58 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase + +T = TypeVar("T") + + +class PickleStorageForwardCache(PickleStorageBase[T]): + def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def get(self, name: str) -> T: + 
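+ # Try the in-memory cache first; on a miss, read from the underlying storage and cache the result.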
cache_item = self._get_cache(name) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.get(name) + self._set_cache(name, latent) + return latent + + def save(self, name: str, data: T) -> None: + self._underlying_storage.save(name, data) + self._set_cache(name, data) + self._on_changed(data) + + def delete(self, name: str) -> None: + self._underlying_storage.delete(name) + if name in self._cache: + del self._cache[name] + self._on_deleted(name) + + def _get_cache(self, name: str) -> Optional[T]: + return None if name not in self._cache else self._cache[name] + + def _set_cache(self, name: str, data: T): + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py new file mode 100644 index 00000000000..0b3c9af7a33 --- /dev/null +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from pathlib import Path +from typing import TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase + +T = TypeVar("T") + + +class PickleStorageTorch(PickleStorageBase[T]): + """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" + + def __init__(self, output_folder: Path, item_type_name: "str"): + self._output_folder = output_folder + self._output_folder.mkdir(parents=True, exist_ok=True) + self._item_type_name = item_type_name + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all_items() + + def get(self, name: str) -> T: + latent_path = self._get_path(name) + return torch.load(latent_path) + + def save(self, name: str, data: T) -> None: + self._output_folder.mkdir(parents=True, exist_ok=True) + latent_path = self._get_path(name) + torch.save(data, latent_path) + + def delete(self, name: str) -> None: + latent_path = self._get_path(name) + latent_path.unlink() + + def _get_path(self, name: str) -> Path: + return self._output_folder / name + + def _delete_all_items(self) -> None: + """ + Deletes all pickled items from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + if not self._invoker: + raise ValueError("Invoker is not set. 
Must call `start()` first.") + + deleted_latents_count = 0 + freed_space = 0 + for latents_file in Path(self._output_folder).glob("*"): + if latents_file.is_file(): + freed_space += latents_file.stat().st_size + deleted_latents_count += 1 + latents_file.unlink() + if deleted_latents_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 97a62246fbc..6756b1f5c6c 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -216,48 +216,46 @@ def get_dto(self, image_name: str) -> ImageDTO: return self._services.images.get_dto(image_name) -class LatentsInterface(InvocationContextInterface): +class TensorsInterface(InvocationContextInterface): def save(self, tensor: Tensor) -> str: """ - Saves a latents tensor, returning its name. + Saves a tensor, returning its name. - :param tensor: The latents tensor to save. + :param tensor: The tensor to save. """ # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. # "mask", "noise", "masked_latents", etc. # # Retaining that capability in this wrapper would require either many different methods - # to save latents, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all latents. + # to save tensors, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all tensors. # # This has a very minor impact as we don't use them after a session completes. - # Previously, invocations chose the name for their latents. This is a bit risky, so we + # Previously, invocations chose the name for their tensors. This is a bit risky, so we # will generate a name for them instead. We use a uuid to ensure the name is unique. # - # Because the name of the latents file will includes the session and invocation IDs, + # Because the name of the tensors file will includes the session and invocation IDs, # we don't need to worry about collisions. A truncated UUIDv4 is fine. name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.latents.save( + self._services.tensors.save( name=name, data=tensor, ) return name - def get(self, latents_name: str) -> Tensor: + def get(self, tensor_name: str) -> Tensor: """ - Gets a latents tensor by name. + Gets a tensor by name. - :param latents_name: The name of the latents tensor to get. + :param tensor_name: The name of the tensor to get. """ - return self._services.latents.get(latents_name) + return self._services.tensors.get(tensor_name) class ConditioningInterface(InvocationContextInterface): - # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed to work with Tensors only. We have to fudge the types here. def save(self, conditioning_data: ConditioningFieldData) -> str: """ Saves a conditioning data object, returning its name. @@ -265,15 +263,12 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. """ - # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. 
- # - # See comment for `LatentsInterface.save` for more info about this method (it's very - # similar). + # See comment in TensorsInterface.save for why we generate the name here. - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning" - self._services.latents.save( + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" + self._services.conditioning.save( name=name, - data=conditioning_data, # type: ignore [arg-type] + data=conditioning_data, ) return name @@ -284,7 +279,7 @@ def get(self, conditioning_name: str) -> ConditioningFieldData: :param conditioning_name: The name of the conditioning data to get. """ - return self._services.latents.get(conditioning_name) # type: ignore [return-value] + return self._services.conditioning.get(conditioning_name) class ModelsInterface(InvocationContextInterface): @@ -400,7 +395,7 @@ class InvocationContext: def __init__( self, images: ImagesInterface, - latents: LatentsInterface, + tensors: TensorsInterface, conditioning: ConditioningInterface, models: ModelsInterface, logger: LoggerInterface, @@ -412,8 +407,8 @@ def __init__( ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" - self.latents = latents - """Provides methods to save and get latents tensors, including image, noise, masks, and masked images.""" + self.tensors = tensors + """Provides methods to save and get tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning """Provides methods to save and get conditioning data.""" self.models = models @@ -532,7 +527,7 @@ def build_invocation_context( logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) - latents = LatentsInterface(services=services, context_data=context_data) + tensors = TensorsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) config = ConfigInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data) @@ -543,7 +538,7 @@ def build_invocation_context( images=images, logger=logger, config=config, - latents=latents, + tensors=tensors, models=models, context_data=context_data, util=util, From a7f91b3e01e2f9970cbf352ea3f9b8f85d76247b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:17:23 +1100 Subject: [PATCH 35/67] tidy(nodes): do not refer to files as latents in `PickleStorageTorch` --- .../pickle_storage/pickle_storage_torch.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py index 0b3c9af7a33..7b18dc0625e 100644 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -48,15 +48,15 @@ def _delete_all_items(self) -> None: if not self._invoker: raise ValueError("Invoker is not set. 
Must call `start()` first.") - deleted_latents_count = 0 + deleted_count = 0 freed_space = 0 - for latents_file in Path(self._output_folder).glob("*"): - if latents_file.is_file(): - freed_space += latents_file.stat().st_size - deleted_latents_count += 1 - latents_file.unlink() - if deleted_latents_count > 0: + for file in Path(self._output_folder).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: freed_space_in_mb = round(freed_space / 1024 / 1024, 2) self._invoker.services.logger.info( - f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" + f"Deleted {deleted_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" ) From 0266946d3d86fd4d85d7088a93faa7945b1c6144 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:35:58 +1100 Subject: [PATCH 36/67] fix(nodes): add super init to `PickleStorageTorch` --- invokeai/app/services/pickle_storage/pickle_storage_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py index 7b18dc0625e..de411bbf47d 100644 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -15,6 +15,7 @@ class PickleStorageTorch(PickleStorageBase[T]): """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" def __init__(self, output_folder: Path, item_type_name: "str"): + super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) self._item_type_name = item_type_name From fd30cb4d9057c5e7244df0c8e157128e6a2188f6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:43:33 +1100 Subject: [PATCH 37/67] feat(nodes): ItemStorageABC typevar no longer bound to pydantic.BaseModel This bound is totally unnecessary. There's no requirement for any implementation of `ItemStorageABC` to work only on pydantic models. 
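For illustration, the effect of the relaxed bound, as a minimal sketch (the `ItemStorageABC` name matches the diff below; the specializations in the trailing comment are the ones this series needs):

    from abc import ABC, abstractmethod
    from typing import Generic, TypeVar

    T = TypeVar("T")  # previously: TypeVar("T", bound=BaseModel)


    class ItemStorageABC(ABC, Generic[T]):
        @abstractmethod
        def get(self, item_id: str) -> T:
            ...

    # With the bound removed, non-pydantic payloads become valid type arguments,
    # e.g. ItemStorageABC[torch.Tensor] or ItemStorageABC[ConditioningFieldData].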
--- invokeai/app/services/item_storage/item_storage_base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index c93edf5188d..d7366791594 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -1,9 +1,7 @@ from abc import ABC, abstractmethod from typing import Callable, Generic, TypeVar -from pydantic import BaseModel - -T = TypeVar("T", bound=BaseModel) +T = TypeVar("T") class ItemStorageABC(ABC, Generic[T]): From 25386a76ef52e7e6eafa524408cf6bb33e1ad356 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:44:30 +1100 Subject: [PATCH 38/67] tidy(nodes): do not refer to files as latents in `PickleStorageTorch` (again) --- .../services/pickle_storage/pickle_storage_torch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py index de411bbf47d..16f0d7bb7ad 100644 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -25,17 +25,17 @@ def start(self, invoker: Invoker) -> None: self._delete_all_items() def get(self, name: str) -> T: - latent_path = self._get_path(name) - return torch.load(latent_path) + file_path = self._get_path(name) + return torch.load(file_path) def save(self, name: str, data: T) -> None: self._output_folder.mkdir(parents=True, exist_ok=True) - latent_path = self._get_path(name) - torch.save(data, latent_path) + file_path = self._get_path(name) + torch.save(data, file_path) def delete(self, name: str) -> None: - latent_path = self._get_path(name) - latent_path.unlink() + file_path = self._get_path(name) + file_path.unlink() def _get_path(self, name: str) -> Path: return self._output_folder / name From fe0391c86bd819536c4f37f384792c66a1709883 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:39:03 +1100 Subject: [PATCH 39/67] feat(nodes): use `ItemStorageABC` for tensors and conditioning Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both. There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution. The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items. 
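Roughly, the improved naming works like this (a simplified sketch with a hypothetical `Storage` class; `ItemStorageEphemeralDisk` in the diff below does the same thing via a cached property):

    import typing
    from typing import Generic, TypeVar
    from uuid import uuid4

    T = TypeVar("T")


    class Storage(Generic[T]):
        def _new_item_id(self) -> str:
            # __orig_class__ holds the parameterized generic (e.g. Storage[Tensor]),
            # letting us recover the type argument's name at runtime. It is only
            # set after __init__ returns, hence the lazy lookup.
            item_class_name = typing.get_args(self.__orig_class__)[0].__name__
            return f"{item_class_name}_{uuid4()}"


    print(Storage[int]()._new_item_id())  # e.g. "int_d9e0c7ba-..."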
--- invokeai/app/api/dependencies.py | 10 +-- invokeai/app/services/invocation_services.py | 5 +- .../item_storage/item_storage_base.py | 2 +- .../item_storage_ephemeral_disk.py | 72 +++++++++++++++++++ .../item_storage_forward_cache.py | 61 ++++++++++++++++ .../item_storage/item_storage_memory.py | 3 +- .../pickle_storage/pickle_storage_base.py | 45 ------------ .../pickle_storage_forward_cache.py | 58 --------------- .../pickle_storage/pickle_storage_torch.py | 63 ---------------- .../app/services/shared/invocation_context.py | 30 +------- 10 files changed, 145 insertions(+), 204 deletions(-) create mode 100644 invokeai/app/services/item_storage/item_storage_ephemeral_disk.py create mode 100644 invokeai/app/services/item_storage/item_storage_forward_cache.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_base.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_torch.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 6bb0915cb6e..d6fd970a22d 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,9 +4,9 @@ import torch +from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk +from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory -from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache -from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -90,9 +90,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor")) - conditioning = PickleStorageForwardCache( - PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning") + tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors")) + conditioning = ItemStorageForwardCache( + ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 81885781acb..69599d83a4b 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -29,7 +29,6 @@ from .model_manager.model_manager_base import ModelManagerServiceBase from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase - from .pickle_storage.pickle_storage_base import PickleStorageBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase from .shared.graph import GraphExecutionState @@ -66,8 +65,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", 
workflow_records: "WorkflowRecordsStorageBase", - tensors: "PickleStorageBase[torch.Tensor]", - conditioning: "PickleStorageBase[ConditioningFieldData]", + tensors: "ItemStorageABC[torch.Tensor]", + conditioning: "ItemStorageABC[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index d7366791594..f2d62ea45fb 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -26,7 +26,7 @@ def get(self, item_id: str) -> T: pass @abstractmethod - def set(self, item: T) -> None: + def set(self, item: T) -> str: """ Sets the item. The id will be extracted based on id_field. :param item: the item to set diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py new file mode 100644 index 00000000000..9843d1e54bc --- /dev/null +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -0,0 +1,72 @@ +import typing +from pathlib import Path +from typing import Optional, TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC +from invokeai.app.util.misc import uuid_string + +T = TypeVar("T") + + +class ItemStorageEphemeralDisk(ItemStorageABC[T]): + """Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup.""" + + def __init__(self, output_folder: Path): + super().__init__() + self._output_folder = output_folder + self._output_folder.mkdir(parents=True, exist_ok=True) + self.__item_class_name: Optional[str] = None + + @property + def _item_class_name(self) -> str: + if not self.__item_class_name: + # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason + self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] + return self.__item_class_name + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all_items() + + def get(self, item_id: str) -> T: + file_path = self._get_path(item_id) + return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + + def set(self, item: T) -> str: + self._output_folder.mkdir(parents=True, exist_ok=True) + item_id = f"{self._item_class_name}_{uuid_string()}" + file_path = self._get_path(item_id) + torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] + return item_id + + def delete(self, item_id: str) -> None: + file_path = self._get_path(item_id) + file_path.unlink() + + def _get_path(self, item_id: str) -> Path: + return self._output_folder / item_id + + def _delete_all_items(self) -> None: + """ + Deletes all pickled items from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + if not self._invoker: + raise ValueError("Invoker is not set. 
Must call `start()` first.") + + deleted_count = 0 + freed_space = 0 + for file in Path(self._output_folder).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/item_storage/item_storage_forward_cache.py b/invokeai/app/services/item_storage/item_storage_forward_cache.py new file mode 100644 index 00000000000..d1fe8e13fa9 --- /dev/null +++ b/invokeai/app/services/item_storage/item_storage_forward_cache.py @@ -0,0 +1,61 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC + +T = TypeVar("T") + + +class ItemStorageForwardCache(ItemStorageABC[T]): + """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size.""" + + def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def get(self, item_id: str) -> T: + cache_item = self._get_cache(item_id) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.get(item_id) + self._set_cache(item_id, latent) + return latent + + def set(self, item: T) -> str: + item_id = self._underlying_storage.set(item) + self._set_cache(item_id, item) + self._on_changed(item) + return item_id + + def delete(self, item_id: str) -> None: + self._underlying_storage.delete(item_id) + if item_id in self._cache: + del self._cache[item_id] + self._on_deleted(item_id) + + def _get_cache(self, item_id: str) -> Optional[T]: + return None if item_id not in self._cache else self._cache[item_id] + + def _set_cache(self, item_id: str, data: T): + if item_id not in self._cache: + self._cache[item_id] = data + self._cache_ids.put(item_id) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/item_storage/item_storage_memory.py b/invokeai/app/services/item_storage/item_storage_memory.py index d8dd0e06645..6d028745164 100644 --- a/invokeai/app/services/item_storage/item_storage_memory.py +++ b/invokeai/app/services/item_storage/item_storage_memory.py @@ -34,7 +34,7 @@ def get(self, item_id: str) -> T: self._items[item_id] = item return item - def set(self, item: T) -> None: + def set(self, item: T) -> str: item_id = getattr(item, self._id_field) if item_id in self._items: # If item already exists, remove it and add it to the end @@ -44,6 +44,7 @@ def set(self, item: T) -> None: self._items.popitem(last=False) self._items[item_id] = item self._on_changed(item) + return item_id def delete(self, item_id: str) -> None: # This is a no-op if the item doesn't exist. 
diff --git a/invokeai/app/services/pickle_storage/pickle_storage_base.py b/invokeai/app/services/pickle_storage/pickle_storage_base.py deleted file mode 100644 index 558b97c0f1b..00000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_base.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC, abstractmethod -from typing import Callable, Generic, TypeVar - -T = TypeVar("T") - - -class PickleStorageBase(ABC, Generic[T]): - """Responsible for storing and retrieving non-serializable data using a pickler.""" - - _on_changed_callbacks: list[Callable[[T], None]] - _on_deleted_callbacks: list[Callable[[str], None]] - - def __init__(self) -> None: - self._on_changed_callbacks = [] - self._on_deleted_callbacks = [] - - @abstractmethod - def get(self, name: str) -> T: - pass - - @abstractmethod - def save(self, name: str, data: T) -> None: - pass - - @abstractmethod - def delete(self, name: str) -> None: - pass - - def on_changed(self, on_changed: Callable[[T], None]) -> None: - """Register a callback for when an item is changed""" - self._on_changed_callbacks.append(on_changed) - - def on_deleted(self, on_deleted: Callable[[str], None]) -> None: - """Register a callback for when an item is deleted""" - self._on_deleted_callbacks.append(on_deleted) - - def _on_changed(self, item: T) -> None: - for callback in self._on_changed_callbacks: - callback(item) - - def _on_deleted(self, item_id: str) -> None: - for callback in self._on_deleted_callbacks: - callback(item_id) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py deleted file mode 100644 index 3002d9e045d..00000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py +++ /dev/null @@ -1,58 +0,0 @@ -from queue import Queue -from typing import Optional, TypeVar - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase - -T = TypeVar("T") - - -class PickleStorageForwardCache(PickleStorageBase[T]): - def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20): - super().__init__() - self._underlying_storage = underlying_storage - self._cache: dict[str, T] = {} - self._cache_ids = Queue[str]() - self._max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self._underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self._underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, name: str) -> T: - cache_item = self._get_cache(name) - if cache_item is not None: - return cache_item - - latent = self._underlying_storage.get(name) - self._set_cache(name, latent) - return latent - - def save(self, name: str, data: T) -> None: - self._underlying_storage.save(name, data) - self._set_cache(name, data) - self._on_changed(data) - - def delete(self, name: str) -> None: - self._underlying_storage.delete(name) - if name in self._cache: - del self._cache[name] - self._on_deleted(name) - - def _get_cache(self, name: str) -> Optional[T]: - return None if name not in self._cache else self._cache[name] - - def _set_cache(self, name: str, data: T): - if name not in self._cache: - self._cache[name] = data - self._cache_ids.put(name) - 
if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py deleted file mode 100644 index 16f0d7bb7ad..00000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from pathlib import Path -from typing import TypeVar - -import torch - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase - -T = TypeVar("T") - - -class PickleStorageTorch(PickleStorageBase[T]): - """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" - - def __init__(self, output_folder: Path, item_type_name: "str"): - super().__init__() - self._output_folder = output_folder - self._output_folder.mkdir(parents=True, exist_ok=True) - self._item_type_name = item_type_name - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_items() - - def get(self, name: str) -> T: - file_path = self._get_path(name) - return torch.load(file_path) - - def save(self, name: str, data: T) -> None: - self._output_folder.mkdir(parents=True, exist_ok=True) - file_path = self._get_path(name) - torch.save(data, file_path) - - def delete(self, name: str) -> None: - file_path = self._get_path(name) - file_path.unlink() - - def _get_path(self, name: str) -> Path: - return self._output_folder / name - - def _delete_all_items(self) -> None: - """ - Deletes all pickled items from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_folder).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 6756b1f5c6c..baff47a3df4 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -12,7 +12,6 @@ from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.app.util.misc import uuid_string from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.model_manager import ModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType @@ -224,26 +223,7 @@ def save(self, tensor: Tensor) -> str: :param tensor: The tensor to save. """ - # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. - # "mask", "noise", "masked_latents", etc. - # - # Retaining that capability in this wrapper would require either many different methods - # to save tensors, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all tensors. 
- # - # This has a very minor impact as we don't use them after a session completes. - - # Previously, invocations chose the name for their tensors. This is a bit risky, so we - # will generate a name for them instead. We use a uuid to ensure the name is unique. - # - # Because the name of the tensors file will includes the session and invocation IDs, - # we don't need to worry about collisions. A truncated UUIDv4 is fine. - - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.tensors.save( - name=name, - data=tensor, - ) + name = self._services.tensors.set(item=tensor) return name def get(self, tensor_name: str) -> Tensor: @@ -263,13 +243,7 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. """ - # See comment in TensorsInterface.save for why we generate the name here. - - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.conditioning.save( - name=name, - data=conditioning_data, - ) + name = self._services.conditioning.set(item=conditioning_data) return name def get(self, conditioning_name: str) -> ConditioningFieldData: From 7fe5283e7400b3f181a200367f3163ee9d8db637 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:50:30 +1100 Subject: [PATCH 40/67] feat(nodes): create helper function to generate the item ID --- .../app/services/item_storage/item_storage_ephemeral_disk.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 9843d1e54bc..377c9c39b3e 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -37,7 +37,7 @@ def get(self, item_id: str) -> T: def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True) - item_id = f"{self._item_class_name}_{uuid_string()}" + item_id = self._new_item_id() file_path = self._get_path(item_id) torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] return item_id @@ -49,6 +49,9 @@ def delete(self, item_id: str) -> None: def _get_path(self, item_id: str) -> Path: return self._output_folder / item_id + def _new_item_id(self) -> str: + return f"{self._item_class_name}_{uuid_string()}" + def _delete_all_items(self) -> None: """ Deletes all pickled items from disk. 
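Taken together, the reworked services are wired up and used like this (a usage sketch assuming an InvokeAI checkout at this point in the series; the output folder is illustrative):

    from pathlib import Path

    import torch

    from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk
    from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache

    # Disk-backed ephemeral storage for tensors, fronted by a small in-memory cache.
    tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](Path("outputs") / "tensors"))

    # The service generates and returns the item ID (e.g. "Tensor_<uuid>").
    item_id = tensors.set(torch.zeros(1))
    assert torch.equal(tensors.get(item_id), torch.zeros(1))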
From 54a67459bf1b6d5ea1a21d2f8cdb8533ed05a7c6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:51:04 +1100 Subject: [PATCH 41/67] feat(nodes): support custom save and load functions in `ItemStorageEphemeralDisk` --- .../item_storage/item_storage_common.py | 10 +++++++ .../item_storage_ephemeral_disk.py | 26 +++++++++++++++---- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_common.py b/invokeai/app/services/item_storage/item_storage_common.py index 8fd677c71b7..7f9bd7bd4ef 100644 --- a/invokeai/app/services/item_storage/item_storage_common.py +++ b/invokeai/app/services/item_storage/item_storage_common.py @@ -1,5 +1,15 @@ +from pathlib import Path +from typing import Callable, TypeAlias, TypeVar + + class ItemNotFoundError(KeyError): """Raised when an item is not found in storage""" def __init__(self, item_id: str) -> None: super().__init__(f"Item with id {item_id} not found") + + +T = TypeVar("T") + +SaveFunc: TypeAlias = Callable[[T, Path], None] +LoadFunc: TypeAlias = Callable[[Path], T] diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 377c9c39b3e..4dc67129dac 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -6,18 +6,31 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC +from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc from invokeai.app.util.misc import uuid_string T = TypeVar("T") class ItemStorageEphemeralDisk(ItemStorageABC[T]): - """Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup.""" - - def __init__(self, output_folder: Path): + """Provides a disk-backed ephemeral storage. The storage is cleared at startup. + + :param output_folder: The folder where the items will be stored + :param save: The function to use to save the items to disk [torch.save] + :param load: The function to use to load the items from disk [torch.load] + """ + + def __init__( + self, + output_folder: Path, + save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] + load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] + ): super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) + self._save = save + self._load = load self.__item_class_name: Optional[str] = None @property @@ -33,13 +46,13 @@ def start(self, invoker: Invoker) -> None: def get(self, item_id: str) -> T: file_path = self._get_path(item_id) - return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + return self._load(file_path) def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True) item_id = self._new_item_id() file_path = self._get_path(item_id) - torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] + self._save(item, file_path) return item_id def delete(self, item_id: str) -> None: @@ -58,6 +71,9 @@ def _delete_all_items(self) -> None: Must be called after we have access to `self._invoker` (e.g. in `start()`). """ + # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have + # to manually clear them on startup anyways. 
This is a bit simpler and more reliable. + if not self._invoker: raise ValueError("Invoker is not set. Must call `start()` first.") From 8b6e32269702a837d4f785cf56ae545f69df43ef Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 22:54:52 +1100 Subject: [PATCH 42/67] feat(nodes): support custom exception in ephemeral disk storage --- .../item_storage/item_storage_ephemeral_disk.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 4dc67129dac..97c767c87d7 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -1,12 +1,12 @@ import typing from pathlib import Path -from typing import Optional, TypeVar +from typing import Optional, Type, TypeVar import torch from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC -from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc +from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc from invokeai.app.util.misc import uuid_string T = TypeVar("T") @@ -18,6 +18,7 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]): :param output_folder: The folder where the items will be stored :param save: The function to use to save the items to disk [torch.save] :param load: The function to use to load the items from disk [torch.load] + :param load_exc: The exception that is raised when an item is not found [FileNotFoundError] """ def __init__( @@ -25,12 +26,14 @@ def __init__( output_folder: Path, save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] + load_exc: Type[Exception] = FileNotFoundError, ): super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) self._save = save self._load = load + self._load_exc = load_exc self.__item_class_name: Optional[str] = None @property @@ -46,7 +49,10 @@ def start(self, invoker: Invoker) -> None: def get(self, item_id: str) -> T: file_path = self._get_path(item_id) - return self._load(file_path) + try: + return self._load(file_path) + except self._load_exc as e: + raise ItemNotFoundError(item_id) from e def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True) From 06429028c8015d050fe3cb87951c6362e8702ac4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 23:30:46 +1100 Subject: [PATCH 43/67] revert(nodes): revert making tensors/conditioning use item storage Turns out they are just different enough in purpose that the implementations would be rather unintuitive. I've made a separate ObjectSerializer service to handle tensors and conditioning. Refined the class a bit too. 
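For reference, the contract of the new service, reduced to a self-contained in-memory sketch. The dict backing and generated names here are illustrative assumptions; the real implementations below persist objects to disk with `torch.save`/`torch.load` and raise `ObjectNotFoundError` on a miss.

    # In-memory sketch of the ObjectSerializer contract (illustrative only).
    from typing import Generic, TypeVar
    from uuid import uuid4

    T = TypeVar("T")

    class InMemoryObjectSerializer(Generic[T]):
        """Toy stand-in for ObjectSerializerBase: save() returns a generated name, load() retrieves by it."""

        def __init__(self) -> None:
            self._objects: dict[str, T] = {}

        def save(self, obj: T) -> str:
            name = f"obj_{uuid4()}"
            self._objects[name] = obj
            return name

        def load(self, name: str) -> T:
            try:
                return self._objects[name]
            except KeyError as e:
                # the real service raises ObjectNotFoundError here
                raise KeyError(f"Object with name {name} not found") from e

        def delete(self, name: str) -> None:
            self._objects.pop(name, None)

    serializer = InMemoryObjectSerializer[str]()
    name = serializer.save("hello")
    assert serializer.load(name) == "hello"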
--- invokeai/app/api/dependencies.py | 10 +- invokeai/app/invocations/latent.py | 24 ++--- invokeai/app/invocations/primitives.py | 2 +- invokeai/app/services/invocation_services.py | 6 +- .../item_storage/item_storage_base.py | 8 +- .../item_storage/item_storage_common.py | 10 -- .../item_storage_ephemeral_disk.py | 97 ------------------- .../item_storage_forward_cache.py | 61 ------------ .../item_storage/item_storage_memory.py | 3 +- .../object_serializer_base.py | 53 ++++++++++ .../object_serializer_common.py | 5 + .../object_serializer_ephemeral_disk.py | 84 ++++++++++++++++ .../object_serializer_forward_cache.py | 61 ++++++++++++ .../app/services/shared/invocation_context.py | 24 ++--- 14 files changed, 243 insertions(+), 205 deletions(-) delete mode 100644 invokeai/app/services/item_storage/item_storage_ephemeral_disk.py delete mode 100644 invokeai/app/services/item_storage/item_storage_forward_cache.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_base.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_common.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_forward_cache.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index d6fd970a22d..0c80494616f 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,9 +4,9 @@ import torch -from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk -from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory +from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -90,9 +90,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors")) - conditioning = ItemStorageForwardCache( - ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") + tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors")) + conditioning = ObjectSerializerForwardCache( + ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 94440d3e2aa..4137ab6e2f6 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -304,11 +304,11 @@ def get_conditioning_data( unet, seed, ) -> ConditioningData: - positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name) + positive_cond_data = 
context.conditioning.load(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = c.extra_conditioning - negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name) + negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( @@ -621,10 +621,10 @@ def prep_inpaint_mask(self, context: InvocationContext, latents): if self.denoise_mask is None: return None, None - mask = context.tensors.get(self.denoise_mask.mask_name) + mask = context.tensors.load(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name) + masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name) else: masked_latents = None @@ -636,11 +636,11 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: seed = None noise = None if self.noise is not None: - noise = context.tensors.get(self.noise.latents_name) + noise = context.tensors.load(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation): alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.tensors.get(self.latents_a.latents_name) - latents_b = context.tensors.get(self.latents_b.latents_name) + latents_a = context.tensors.load(self.latents_a.latents_name) + latents_b = context.tensors.load(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR diff 
--git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 082d5432ccf..d0f95c92d02 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -344,7 +344,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 69599d83a4b..e893be87636 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase + if TYPE_CHECKING: from logging import Logger @@ -65,8 +67,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", - tensors: "ItemStorageABC[torch.Tensor]", - conditioning: "ItemStorageABC[ConditioningFieldData]", + tensors: "ObjectSerializerBase[torch.Tensor]", + conditioning: "ObjectSerializerBase[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index f2d62ea45fb..ef227ba241c 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod from typing import Callable, Generic, TypeVar -T = TypeVar("T") +from pydantic import BaseModel + +T = TypeVar("T", bound=BaseModel) class ItemStorageABC(ABC, Generic[T]): @@ -26,9 +28,9 @@ def get(self, item_id: str) -> T: pass @abstractmethod - def set(self, item: T) -> str: + def set(self, item: T) -> None: """ - Sets the item. The id will be extracted based on id_field. + Sets the item. 
:param item: the item to set """ pass diff --git a/invokeai/app/services/item_storage/item_storage_common.py b/invokeai/app/services/item_storage/item_storage_common.py index 7f9bd7bd4ef..8fd677c71b7 100644 --- a/invokeai/app/services/item_storage/item_storage_common.py +++ b/invokeai/app/services/item_storage/item_storage_common.py @@ -1,15 +1,5 @@ -from pathlib import Path -from typing import Callable, TypeAlias, TypeVar - - class ItemNotFoundError(KeyError): """Raised when an item is not found in storage""" def __init__(self, item_id: str) -> None: super().__init__(f"Item with id {item_id} not found") - - -T = TypeVar("T") - -SaveFunc: TypeAlias = Callable[[T, Path], None] -LoadFunc: TypeAlias = Callable[[Path], T] diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py deleted file mode 100644 index 97c767c87d7..00000000000 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ /dev/null @@ -1,97 +0,0 @@ -import typing -from pathlib import Path -from typing import Optional, Type, TypeVar - -import torch - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC -from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc -from invokeai.app.util.misc import uuid_string - -T = TypeVar("T") - - -class ItemStorageEphemeralDisk(ItemStorageABC[T]): - """Provides a disk-backed ephemeral storage. The storage is cleared at startup. - - :param output_folder: The folder where the items will be stored - :param save: The function to use to save the items to disk [torch.save] - :param load: The function to use to load the items from disk [torch.load] - :param load_exc: The exception that is raised when an item is not found [FileNotFoundError] - """ - - def __init__( - self, - output_folder: Path, - save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] - load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] - load_exc: Type[Exception] = FileNotFoundError, - ): - super().__init__() - self._output_folder = output_folder - self._output_folder.mkdir(parents=True, exist_ok=True) - self._save = save - self._load = load - self._load_exc = load_exc - self.__item_class_name: Optional[str] = None - - @property - def _item_class_name(self) -> str: - if not self.__item_class_name: - # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason - self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] - return self.__item_class_name - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_items() - - def get(self, item_id: str) -> T: - file_path = self._get_path(item_id) - try: - return self._load(file_path) - except self._load_exc as e: - raise ItemNotFoundError(item_id) from e - - def set(self, item: T) -> str: - self._output_folder.mkdir(parents=True, exist_ok=True) - item_id = self._new_item_id() - file_path = self._get_path(item_id) - self._save(item, file_path) - return item_id - - def delete(self, item_id: str) -> None: - file_path = self._get_path(item_id) - file_path.unlink() - - def _get_path(self, item_id: str) -> Path: - return self._output_folder / item_id - - def _new_item_id(self) -> str: - return f"{self._item_class_name}_{uuid_string()}" - - def 
_delete_all_items(self) -> None: - """ - Deletes all pickled items from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - - # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have - # to manually clear them on startup anyways. This is a bit simpler and more reliable. - - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_folder).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/item_storage/item_storage_forward_cache.py b/invokeai/app/services/item_storage/item_storage_forward_cache.py deleted file mode 100644 index d1fe8e13fa9..00000000000 --- a/invokeai/app/services/item_storage/item_storage_forward_cache.py +++ /dev/null @@ -1,61 +0,0 @@ -from queue import Queue -from typing import Optional, TypeVar - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC - -T = TypeVar("T") - - -class ItemStorageForwardCache(ItemStorageABC[T]): - """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size.""" - - def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20): - super().__init__() - self._underlying_storage = underlying_storage - self._cache: dict[str, T] = {} - self._cache_ids = Queue[str]() - self._max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self._underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self._underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, item_id: str) -> T: - cache_item = self._get_cache(item_id) - if cache_item is not None: - return cache_item - - latent = self._underlying_storage.get(item_id) - self._set_cache(item_id, latent) - return latent - - def set(self, item: T) -> str: - item_id = self._underlying_storage.set(item) - self._set_cache(item_id, item) - self._on_changed(item) - return item_id - - def delete(self, item_id: str) -> None: - self._underlying_storage.delete(item_id) - if item_id in self._cache: - del self._cache[item_id] - self._on_deleted(item_id) - - def _get_cache(self, item_id: str) -> Optional[T]: - return None if item_id not in self._cache else self._cache[item_id] - - def _set_cache(self, item_id: str, data: T): - if item_id not in self._cache: - self._cache[item_id] = data - self._cache_ids.put(item_id) - if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/item_storage/item_storage_memory.py b/invokeai/app/services/item_storage/item_storage_memory.py index 6d028745164..d8dd0e06645 100644 --- a/invokeai/app/services/item_storage/item_storage_memory.py +++ b/invokeai/app/services/item_storage/item_storage_memory.py @@ -34,7 +34,7 @@ def get(self, item_id: str) -> T: self._items[item_id] = item return item - def set(self, item: T) -> str: + def set(self, item: T) -> None: item_id = 
getattr(item, self._id_field) if item_id in self._items: # If item already exists, remove it and add it to the end @@ -44,7 +44,6 @@ def set(self, item: T) -> str: self._items.popitem(last=False) self._items[item_id] = item self._on_changed(item) - return item_id def delete(self, item_id: str) -> None: # This is a no-op if the item doesn't exist. diff --git a/invokeai/app/services/object_serializer/object_serializer_base.py b/invokeai/app/services/object_serializer/object_serializer_base.py new file mode 100644 index 00000000000..b01a641d8fb --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_base.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from typing import Callable, Generic, TypeVar + +T = TypeVar("T") + + +class ObjectSerializerBase(ABC, Generic[T]): + """Saves and loads arbitrary python objects.""" + + def __init__(self) -> None: + self._on_saved_callbacks: list[Callable[[str, T], None]] = [] + self._on_deleted_callbacks: list[Callable[[str], None]] = [] + + @abstractmethod + def load(self, name: str) -> T: + """ + Loads the object. + :param name: The name of the object to load. + :raises ObjectNotFoundError: if the object is not found + """ + pass + + @abstractmethod + def save(self, obj: T) -> str: + """ + Saves the object, returning its name. + :param obj: The object to save. + """ + pass + + @abstractmethod + def delete(self, name: str) -> None: + """ + Deletes the object, if it exists. + :param name: The name of the object to delete. + """ + pass + + def on_saved(self, on_saved: Callable[[str, T], None]) -> None: + """Register a callback for when an object is saved""" + self._on_saved_callbacks.append(on_saved) + + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an object is deleted""" + self._on_deleted_callbacks.append(on_deleted) + + def _on_saved(self, name: str, obj: T) -> None: + for callback in self._on_saved_callbacks: + callback(name, obj) + + def _on_deleted(self, name: str) -> None: + for callback in self._on_deleted_callbacks: + callback(name) diff --git a/invokeai/app/services/object_serializer/object_serializer_common.py b/invokeai/app/services/object_serializer/object_serializer_common.py new file mode 100644 index 00000000000..7057386541f --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_common.py @@ -0,0 +1,5 @@ +class ObjectNotFoundError(KeyError): + """Raised when an object is not found while loading""" + + def __init__(self, name: str) -> None: + super().__init__(f"Object with name {name} not found") diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py new file mode 100644 index 00000000000..afa868b157f --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -0,0 +1,84 @@ +import typing +from pathlib import Path +from typing import Optional, TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase +from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError +from invokeai.app.util.misc import uuid_string + +T = TypeVar("T") + + +class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): + """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup. 
+ + :param output_folder: The folder where the objects will be stored + """ + + def __init__(self, output_dir: Path): + super().__init__() + self._output_dir = output_dir + self._output_dir.mkdir(parents=True, exist_ok=True) + self.__obj_class_name: Optional[str] = None + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all() + + def load(self, name: str) -> T: + file_path = self._get_path(name) + try: + return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + except FileNotFoundError as e: + raise ObjectNotFoundError(name) from e + + def save(self, obj: T) -> str: + name = self._new_name() + file_path = self._get_path(name) + torch.save(obj, file_path) # pyright: ignore [reportUnknownMemberType] + return name + + def delete(self, name: str) -> None: + file_path = self._get_path(name) + file_path.unlink() + + @property + def _obj_class_name(self) -> str: + if not self.__obj_class_name: + # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason + self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] + return self.__obj_class_name + + def _get_path(self, name: str) -> Path: + return self._output_dir / name + + def _new_name(self) -> str: + return f"{self._obj_class_name}_{uuid_string()}" + + def _delete_all(self) -> None: + """ + Deletes all objects from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have + # to manually clear them on startup anyways. This is a bit simpler and more reliable. + + if not self._invoker: + raise ValueError("Invoker is not set. Must call `start()` first.") + + deleted_count = 0 + freed_space = 0 + for file in Path(self._output_dir).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py new file mode 100644 index 00000000000..40e34e65406 --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -0,0 +1,61 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase + +T = TypeVar("T") + + +class ObjectSerializerForwardCache(ObjectSerializerBase[T]): + """Provides a simple forward cache for an underlying storage. 
The cache is LRU and has a maximum size.""" + + def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def load(self, name: str) -> T: + cache_item = self._get_cache(name) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.load(name) + self._set_cache(name, latent) + return latent + + def save(self, obj: T) -> str: + name = self._underlying_storage.save(obj) + self._set_cache(name, obj) + self._on_saved(name, obj) + return name + + def delete(self, name: str) -> None: + self._underlying_storage.delete(name) + if name in self._cache: + del self._cache[name] + self._on_deleted(name) + + def _get_cache(self, name: str) -> Optional[T]: + return None if name not in self._cache else self._cache[name] + + def _set_cache(self, name: str, data: T): + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index baff47a3df4..8c5a821fd0f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -223,16 +223,16 @@ def save(self, tensor: Tensor) -> str: :param tensor: The tensor to save. """ - name = self._services.tensors.set(item=tensor) - return name + tensor_id = self._services.tensors.save(obj=tensor) + return tensor_id - def get(self, tensor_name: str) -> Tensor: + def load(self, name: str) -> Tensor: """ - Gets a tensor by name. + Loads a tensor by name. - :param tensor_name: The name of the tensor to get. + :param name: The name of the tensor to load. """ - return self._services.tensors.get(tensor_name) + return self._services.tensors.load(name) class ConditioningInterface(InvocationContextInterface): @@ -243,17 +243,17 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. """ - name = self._services.conditioning.set(item=conditioning_data) - return name + conditioning_id = self._services.conditioning.save(obj=conditioning_data) + return conditioning_id - def get(self, conditioning_name: str) -> ConditioningFieldData: + def load(self, name: str) -> ConditioningFieldData: """ - Gets conditioning data by name. + Loads conditioning data by name. - :param conditioning_name: The name of the conditioning data to get. + :param name: The name of the conditioning data to load. """ - return self._services.conditioning.get(conditioning_name) + return self._services.conditioning.load(name) class ModelsInterface(InvocationContextInterface): From 55fa7855610bc1626c9a3ec7a7cbb56833bdcd4e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 23:43:22 +1100 Subject: [PATCH 44/67] tidy(nodes): remove object serializer on_saved It's unused. 
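The `on_deleted` hook is kept. A hedged usage sketch against the classes from the previous patch -- it assumes the two `object_serializer` modules above are importable as patched here, and uses plain `str` objects and a throwaway directory purely for illustration:

    from pathlib import Path

    from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import (
        ObjectSerializerEphemeralDisk,
    )
    from invokeai.app.services.object_serializer.object_serializer_forward_cache import (
        ObjectSerializerForwardCache,
    )

    def log_deleted(name: str) -> None:
        print(f"deleted {name}")

    # Illustrative wiring; the app composes these in api/dependencies.py.
    serializer = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[str](Path("/tmp/objects")))
    serializer.on_deleted(log_deleted)  # the hook lives on ObjectSerializerBase

    name = serializer.save("hello")
    serializer.delete(name)  # fires log_deleted(name) via _on_deleted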
--- .../services/object_serializer/object_serializer_base.py | 9 --------- .../object_serializer/object_serializer_forward_cache.py | 1 - 2 files changed, 10 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_base.py b/invokeai/app/services/object_serializer/object_serializer_base.py index b01a641d8fb..ff19b4a039d 100644 --- a/invokeai/app/services/object_serializer/object_serializer_base.py +++ b/invokeai/app/services/object_serializer/object_serializer_base.py @@ -8,7 +8,6 @@ class ObjectSerializerBase(ABC, Generic[T]): """Saves and loads arbitrary python objects.""" def __init__(self) -> None: - self._on_saved_callbacks: list[Callable[[str, T], None]] = [] self._on_deleted_callbacks: list[Callable[[str], None]] = [] @abstractmethod @@ -36,18 +35,10 @@ def delete(self, name: str) -> None: """ pass - def on_saved(self, on_saved: Callable[[str, T], None]) -> None: - """Register a callback for when an object is saved""" - self._on_saved_callbacks.append(on_saved) - def on_deleted(self, on_deleted: Callable[[str], None]) -> None: """Register a callback for when an object is deleted""" self._on_deleted_callbacks.append(on_deleted) - def _on_saved(self, name: str, obj: T) -> None: - for callback in self._on_saved_callbacks: - callback(name, obj) - def _on_deleted(self, name: str) -> None: for callback in self._on_deleted_callbacks: callback(name) diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index 40e34e65406..2a4ecdd844b 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -41,7 +41,6 @@ def load(self, name: str) -> T: def save(self, obj: T) -> str: name = self._underlying_storage.save(obj) self._set_cache(name, obj) - self._on_saved(name, obj) return name def delete(self, name: str) -> None: From 83ddcc5f3a9f74ee7fc680346660ecb9921c43cf Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:18:58 +1100 Subject: [PATCH 45/67] feat(nodes): allow `_delete_all` in obj serializer to be called at any time `_delete_all` logged how many items it deleted, and had to be called _after_ service start because it needed access to the logger. Move the logger call to the startup method and return the deleted stats from `_delete_all`. This lets `_delete_all` be called at any time. --- .../object_serializer_ephemeral_disk.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py index afa868b157f..9545d1714d7 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -1,4 +1,5 @@ import typing +from dataclasses import dataclass from pathlib import Path from typing import Optional, TypeVar @@ -12,6 +13,12 @@ T = TypeVar("T") +@dataclass +class DeleteAllResult: + deleted_count: int + freed_space_bytes: float + + class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup.
@@ -26,7 +33,12 @@ def __init__(self, output_dir: Path): def start(self, invoker: Invoker) -> None: self._invoker = invoker - self._delete_all() + delete_all_result = self._delete_all() + if delete_all_result.deleted_count > 0: + freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) def load(self, name: str) -> T: file_path = self._get_path(name) @@ -58,18 +70,14 @@ def _get_path(self, name: str) -> Path: def _new_name(self) -> str: return f"{self._obj_class_name}_{uuid_string()}" - def _delete_all(self) -> None: + def _delete_all(self) -> DeleteAllResult: """ Deletes all objects from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). """ # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have # to manually clear them on startup anyways. This is a bit simpler and more reliable. - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - deleted_count = 0 freed_space = 0 for file in Path(self._output_dir).glob("*"): @@ -77,8 +85,4 @@ def _delete_all(self) -> None: freed_space += file.stat().st_size deleted_count += 1 file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) + return DeleteAllResult(deleted_count, freed_space) From 7a2b606001e4cc96436651e580440e8e4cda9d76 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:20:10 +1100 Subject: [PATCH 46/67] tests: add object serializer tests These test both object serializer and its forward cache implementation. 
--- .../test_object_serializer_ephemeral_disk.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 tests/test_object_serializer_ephemeral_disk.py diff --git a/tests/test_object_serializer_ephemeral_disk.py b/tests/test_object_serializer_ephemeral_disk.py new file mode 100644 index 00000000000..fffa65304f6 --- /dev/null +++ b/tests/test_object_serializer_ephemeral_disk.py @@ -0,0 +1,148 @@ +from dataclasses import dataclass +from pathlib import Path + +import pytest +import torch + +from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError +from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache + + +@dataclass +class MockDataclass: + foo: str + + +@pytest.fixture +def obj_serializer(tmp_path: Path): + return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + + +@pytest.fixture +def fwd_cache(tmp_path: Path): + return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2) + + +def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path): + obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + assert obj_serializer._output_dir == tmp_path + + +def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert Path(obj_serializer._output_dir, obj_1_name).exists() + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert obj_serializer.load(obj_1_name).foo == "bar" + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert obj_serializer.load(obj_2_name).foo == "baz" + + with pytest.raises(ObjectNotFoundError): + obj_serializer.load("nonexistent_object_name") + + +def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_2 = MockDataclass(foo="bar") + obj_2_name = obj_serializer.save(obj_2) + + obj_serializer.delete(obj_1_name) + assert not Path(obj_serializer._output_dir, obj_1_name).exists() + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_2 = MockDataclass(foo="bar") + obj_2_name = obj_serializer.save(obj_2) + + delete_all_result = obj_serializer._delete_all() + + assert not Path(obj_serializer._output_dir, obj_1_name).exists() + assert not Path(obj_serializer._output_dir, obj_2_name).exists() + assert delete_all_result.deleted_count == 2 + + +def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): + obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + obj_1_loaded = obj_serializer.load(obj_1_name) + assert isinstance(obj_1_loaded, MockDataclass) + assert obj_1_loaded.foo == "bar" + 
assert obj_1_name.startswith("MockDataclass_") + + obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path) + obj_2_name = obj_serializer.save(9001) + assert obj_serializer.load(obj_2_name) == 9001 + assert obj_2_name.startswith("int_") + + obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path) + obj_3_name = obj_serializer.save("foo") + assert obj_serializer.load(obj_3_name) == "foo" + assert obj_3_name.startswith("str_") + + obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path) + obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) + obj_4_loaded = obj_serializer.load(obj_4_name) + assert isinstance(obj_4_loaded, torch.Tensor) + assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3])) + assert obj_4_name.startswith("Tensor_") + + +def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + fwd_cache = ObjectSerializerForwardCache(obj_serializer) + assert fwd_cache._underlying_storage == obj_serializer + + +def test_obj_serializer_fwd_cache_saves_and_loads(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj = MockDataclass(foo="bar") + obj_name = fwd_cache.save(obj) + obj_loaded = fwd_cache.load(obj_name) + obj_underlying = fwd_cache._underlying_storage.load(obj_name) + assert obj_loaded == obj_underlying + assert obj_loaded.foo == "bar" + + +def test_obj_serializer_fwd_cache_respects_cache_size(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = fwd_cache.save(obj_1) + obj_2 = MockDataclass(foo="baz") + obj_2_name = fwd_cache.save(obj_2) + obj_3 = MockDataclass(foo="qux") + obj_3_name = fwd_cache.save(obj_3) + assert obj_1_name not in fwd_cache._cache + assert obj_2_name in fwd_cache._cache + assert obj_3_name in fwd_cache._cache + # apparently qsize is "not reliable"? + assert fwd_cache._cache_ids.qsize() == 2 + + +def test_obj_serializer_fwd_cache_calls_delete_callback(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + called_name = None + obj_1 = MockDataclass(foo="bar") + + def on_deleted(name: str): + nonlocal called_name + called_name = name + + fwd_cache.on_deleted(on_deleted) + obj_1_name = fwd_cache.save(obj_1) + fwd_cache.delete(obj_1_name) + assert called_name == obj_1_name From 1f27ddc07d8bdd2433857b898f931fbd84e0c153 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:23:47 +1100 Subject: [PATCH 47/67] tidy(nodes): minor spelling correction --- invokeai/app/services/shared/invocation_context.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 8c5a821fd0f..828d3d84904 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -223,8 +223,8 @@ def save(self, tensor: Tensor) -> str: :param tensor: The tensor to save. """ - tensor_id = self._services.tensors.save(obj=tensor) - return tensor_id + name = self._services.tensors.save(obj=tensor) + return name def load(self, name: str) -> Tensor: """ @@ -243,8 +243,8 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. 
""" - conditioning_id = self._services.conditioning.save(obj=conditioning_data) - return conditioning_id + name = self._services.conditioning.save(obj=conditioning_data) + return name def load(self, name: str) -> ConditioningFieldData: """ From bcb85e100db45afea6441fe4ca617bcf4372572a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:36:53 +1100 Subject: [PATCH 48/67] tests: fix broken tests --- .../object_serializer_ephemeral_disk.py | 9 ++++++--- .../object_serializer_forward_cache.py | 10 ++++++---- tests/aa_nodes/test_graph_execution_state.py | 5 +++-- tests/aa_nodes/test_invoker.py | 3 ++- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py index 9545d1714d7..880848a1425 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -1,15 +1,18 @@ import typing from dataclasses import dataclass from pathlib import Path -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar import torch -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.util.misc import uuid_string +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + + T = TypeVar("T") @@ -31,7 +34,7 @@ def __init__(self, output_dir: Path): self._output_dir.mkdir(parents=True, exist_ok=True) self.__obj_class_name: Optional[str] = None - def start(self, invoker: Invoker) -> None: + def start(self, invoker: "Invoker") -> None: self._invoker = invoker delete_all_result = self._delete_all() if delete_all_result.deleted_count > 0: diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index 2a4ecdd844b..c8ca13982c1 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -1,11 +1,13 @@ from queue import Queue -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase T = TypeVar("T") +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + class ObjectSerializerForwardCache(ObjectSerializerBase[T]): """Provides a simple forward cache for an underlying storage. 
The cache is LRU and has a maximum size.""" @@ -17,13 +19,13 @@ def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: self._cache_ids = Queue[str]() self._max_cache_size = max_cache_size - def start(self, invoker: Invoker) -> None: + def start(self, invoker: "Invoker") -> None: self._invoker = invoker start_op = getattr(self._underlying_storage, "start", None) if callable(start_op): start_op(invoker) - def stop(self, invoker: Invoker) -> None: + def stop(self, invoker: "Invoker") -> None: self._invoker = invoker stop_op = getattr(self._underlying_storage, "stop", None) if callable(stop_op): diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index aba7c5694f3..27d2d2230a3 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -60,7 +60,6 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore @@ -74,6 +73,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + tensors=None, + conditioning=None, ) @@ -89,7 +90,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B config=None, context_data=None, images=None, - latents=None, + tensors=None, logger=None, models=None, util=None, diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 2ae4eab58a0..437ea0f00d3 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -63,7 +63,6 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore @@ -77,6 +76,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + tensors=None, + conditioning=None, ) From bc5f356390c70372cc226ba097c4831c0ae77993 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 07:55:36 +1100 Subject: [PATCH 49/67] feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders --- invokeai/app/invocations/noise.py | 5 +++-- invokeai/app/invocations/primitives.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 74b3d6e4cb1..4093030388b 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -5,6 +5,7 @@ from pydantic import field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField +from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -70,8 +71,8 @@ class NoiseOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": return cls( noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + 
width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index d0f95c92d02..2a9cb8cf9bf 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -16,6 +16,7 @@ OutputField, UIComponent, ) +from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.shared.invocation_context import InvocationContext @@ -321,8 +322,8 @@ class LatentsOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": return cls( latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) From 8892df1d97919c442b21ecbb28961320855ca4bd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 08:14:04 +1100 Subject: [PATCH 50/67] Revert "feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders" This reverts commit ef18fc546560277302f3886e456da9a47e8edce0. --- invokeai/app/invocations/noise.py | 5 ++--- invokeai/app/invocations/primitives.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 4093030388b..74b3d6e4cb1 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -5,7 +5,6 @@ from pydantic import field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -71,8 +70,8 @@ class NoiseOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": return cls( noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * LATENT_SCALE_FACTOR, - height=latents.size()[2] * LATENT_SCALE_FACTOR, + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 2a9cb8cf9bf..d0f95c92d02 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -16,7 +16,6 @@ OutputField, UIComponent, ) -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.shared.invocation_context import InvocationContext @@ -322,8 +321,8 @@ class LatentsOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": return cls( latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * LATENT_SCALE_FACTOR, - height=latents.size()[2] * LATENT_SCALE_FACTOR, + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, ) From cb0b389b4b3fbdc45f835c6275c8c68ac963f9ff Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:57:01 +1100 Subject: [PATCH 51/67] tidy(nodes): clarify comment --- invokeai/app/services/shared/invocation_context.py 
| 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 828d3d84904..3d06cf92725 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -279,15 +279,8 @@ def load( :param submodel: The submodel of the model to get. """ - # During this call, the model manager emits events with model loading status. The model - # manager itself has access to the events services, but does not have access to the - # required metadata for the events. - # - # For example, it needs access to the node's ID so that the events can be associated - # with the execution of a specific node. - # - # While this is available within the node, it's tedious to need to pass it in on every - # call. We can avoid that by wrapping the method here. + # The model manager emits events as it loads the model. It needs the context data to build + # the event payloads. return self._services.model_manager.get_model( model_name, base_model, model_type, submodel, context_data=self._context_data From ed772a71079903315b378a8d82f941effc9f7975 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 11:05:33 +1100 Subject: [PATCH 52/67] fix(nodes): use `metadata`/`board_id` if provided by user, overriding `WithMetadata`/`WithBoard`-provided values --- .../app/services/shared/invocation_context.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 3d06cf92725..1ca44b78625 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -167,16 +167,19 @@ def save( **Use this only if you want to override or provide metadata manually!** """ - # If the invocation inherits metadata, use that. Else, use the metadata passed in. - metadata_ = ( - self._context_data.invocation.metadata - if isinstance(self._context_data.invocation, WithMetadata) - else metadata - ) - - # If the invocation inherits WithBoard, use that. Else, use the board_id passed in. - board_ = self._context_data.invocation.board if isinstance(self._context_data.invocation, WithBoard) else None - board_id_ = board_.board_id if board_ is not None else board_id + # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None. + metadata_ = None + if metadata: + metadata_ = metadata + elif isinstance(self._context_data.invocation, WithMetadata): + metadata_ = self._context_data.invocation.metadata + + # If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None. + board_id_ = None + if board_id: + board_id_ = board_id + elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board: + board_id_ = self._context_data.invocation.board.board_id return self._services.images.create( image=image, From c58f8c32694f6eef3322922fac94bd0f8809a87c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 16:09:59 +1100 Subject: [PATCH 53/67] feat(nodes): make delete on startup configurable for obj serializer - The default is to not delete on startup - feels safer. - The two services using this class _do_ delete on startup. 
- The class has "ephemeral" removed from its name. - Tests & app updated for this change. --- invokeai/app/api/dependencies.py | 8 ++- ...eral_disk.py => object_serializer_disk.py} | 22 ++++--- ...disk.py => test_object_serializer_disk.py} | 64 ++++++++++++++----- 3 files changed, 67 insertions(+), 27 deletions(-) rename invokeai/app/services/object_serializer/{object_serializer_ephemeral_disk.py => object_serializer_disk.py} (77%) rename tests/{test_object_serializer_ephemeral_disk.py => test_object_serializer_disk.py} (65%) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 0c80494616f..2acb961aa7a 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -5,7 +5,7 @@ import torch from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory -from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore @@ -90,9 +90,11 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors")) + tensors = ObjectSerializerForwardCache( + ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True) + ) conditioning = ObjectSerializerForwardCache( - ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") + ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True) ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py similarity index 77% rename from invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py rename to invokeai/app/services/object_serializer/object_serializer_disk.py index 880848a1425..174ff15192d 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -22,26 +22,30 @@ class DeleteAllResult: freed_space_bytes: float -class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): - """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup. +class ObjectSerializerDisk(ObjectSerializerBase[T]): + """Provides a disk-backed storage for arbitrary python objects. 
:param output_folder: The folder where the objects will be stored + :param delete_on_startup: If True, all objects in the output folder will be deleted on startup """ - def __init__(self, output_dir: Path): + def __init__(self, output_dir: Path, delete_on_startup: bool = False): super().__init__() self._output_dir = output_dir self._output_dir.mkdir(parents=True, exist_ok=True) + self._delete_on_startup = delete_on_startup self.__obj_class_name: Optional[str] = None def start(self, invoker: "Invoker") -> None: self._invoker = invoker - delete_all_result = self._delete_all() - if delete_all_result.deleted_count > 0: - freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) + + if self._delete_on_startup: + delete_all_result = self._delete_all() + if delete_all_result.deleted_count > 0: + freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) def load(self, name: str) -> T: file_path = self._get_path(name) diff --git a/tests/test_object_serializer_ephemeral_disk.py b/tests/test_object_serializer_disk.py similarity index 65% rename from tests/test_object_serializer_ephemeral_disk.py rename to tests/test_object_serializer_disk.py index fffa65304f6..5ce1e579013 100644 --- a/tests/test_object_serializer_ephemeral_disk.py +++ b/tests/test_object_serializer_disk.py @@ -1,11 +1,14 @@ from dataclasses import dataclass +from logging import Logger from pathlib import Path +from unittest.mock import Mock import pytest import torch +from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError -from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -14,22 +17,31 @@ class MockDataclass: foo: str +def count_files(path: Path): + return len(list(path.iterdir())) + + @pytest.fixture def obj_serializer(tmp_path: Path): - return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + return ObjectSerializerDisk[MockDataclass](tmp_path) @pytest.fixture def fwd_cache(tmp_path: Path): - return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2) + return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) + + +@pytest.fixture +def mock_invoker_with_logger(): + return Mock(Invoker, services=Mock(logger=Mock(Logger))) -def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path): - obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) +def test_obj_serializer_disk_initializes(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._output_dir == tmp_path -def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert Path(obj_serializer._output_dir, 
obj_1_name).exists() @@ -39,7 +51,7 @@ def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEph assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert obj_serializer.load(obj_1_name).foo == "bar" @@ -52,7 +64,7 @@ def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEph obj_serializer.load("nonexistent_object_name") -def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -64,7 +76,7 @@ def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerE assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -78,8 +90,30 @@ def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSeriali assert delete_all_result.deleted_count == 2 -def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): - obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) +def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) + assert obj_serializer._delete_on_startup is False + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_serializer.start(mock_invoker_with_logger) + assert Path(tmp_path, obj_1_name).exists() + + +def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True) + assert obj_serializer._delete_on_startup is True + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_serializer.start(mock_invoker_with_logger) + assert not Path(tmp_path, obj_1_name).exists() + + +def test_obj_serializer_disk_different_types(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -88,17 +122,17 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): assert obj_1_loaded.foo == "bar" assert obj_1_name.startswith("MockDataclass_") - obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path) + obj_serializer = ObjectSerializerDisk[int](tmp_path) obj_2_name = obj_serializer.save(9001) assert obj_serializer.load(obj_2_name) == 9001 assert obj_2_name.startswith("int_") - obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path) + obj_serializer = ObjectSerializerDisk[str](tmp_path) obj_3_name = obj_serializer.save("foo") assert obj_serializer.load(obj_3_name) == "foo" assert obj_3_name.startswith("str_") - obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path) + obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path) obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) 
     obj_4_loaded = obj_serializer.load(obj_4_name)
     assert isinstance(obj_4_loaded, torch.Tensor)
@@ -106,7 +140,7 @@
     assert obj_4_name.startswith("Tensor_")


-def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]):
+def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
     fwd_cache = ObjectSerializerForwardCache(obj_serializer)
     assert fwd_cache._underlying_storage == obj_serializer

From ff249a231522e7a43417a90d48fa6ab9320b7fdd Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 09:41:23 +1100
Subject: [PATCH 54/67] tidy(nodes): do not unnecessarily store invoker

---
 .../app/services/object_serializer/object_serializer_disk.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py
index 174ff15192d..b3827e16a92 100644
--- a/invokeai/app/services/object_serializer/object_serializer_disk.py
+++ b/invokeai/app/services/object_serializer/object_serializer_disk.py
@@ -37,13 +37,11 @@ def __init__(self, output_dir: Path, delete_on_startup: bool = False):
         self.__obj_class_name: Optional[str] = None

     def start(self, invoker: "Invoker") -> None:
-        self._invoker = invoker
-
         if self._delete_on_startup:
             delete_all_result = self._delete_all()
             if delete_all_result.deleted_count > 0:
                 freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2)
-                self._invoker.services.logger.info(
+                invoker.services.logger.info(
                     f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)"
                 )

From c1e5cd589364913ffd382933bf06bce26e762f49 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 10:06:50 +1100
Subject: [PATCH 55/67] tidy(nodes): "latents" -> "obj"

---
 .../object_serializer/object_serializer_forward_cache.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
index c8ca13982c1..812731f456a 100644
--- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
+++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
@@ -36,9 +36,9 @@ def load(self, name: str) -> T:
         if cache_item is not None:
             return cache_item

-        latent = self._underlying_storage.load(name)
-        self._set_cache(name, latent)
-        return latent
+        obj = self._underlying_storage.load(name)
+        self._set_cache(name, obj)
+        return obj

     def save(self, obj: T) -> str:
         name = self._underlying_storage.save(obj)

From 6bb2dda3f1b6c262d73d5c717d4e84bd45ac29e3 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 10:10:48 +1100
Subject: [PATCH 56/67] chore(nodes): fix pyright ignore

---
 .../app/services/object_serializer/object_serializer_disk.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py
index b3827e16a92..06f86aa460c 100644
--- a/invokeai/app/services/object_serializer/object_serializer_disk.py
+++ b/invokeai/app/services/object_serializer/object_serializer_disk.py
@@ -66,7 +66,7 @@ def delete(self, name: str) -> None:
     def _obj_class_name(self) -> str:
         if not self.__obj_class_name:
             # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
-            self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__  # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
+            self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__  # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
         return self.__obj_class_name

     def _get_path(self, name: str) -> Path:

From a7207ed8cf2485842271790d0f266b626439816c Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 10:11:31 +1100
Subject: [PATCH 57/67] chore(nodes): update ObjectSerializerForwardCache docstring

---
 .../object_serializer/object_serializer_forward_cache.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
index 812731f456a..b361259a4b1 100644
--- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
+++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
@@ -10,7 +10,10 @@
 class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
-    """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
+    """
+    Provides an LRU cache for an instance of `ObjectSerializerBase`.
+    Saving an object to the cache always writes through to the underlying storage.
+    """

     def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
         super().__init__()

From 199ddd6623ea738f2d61d869fe88e5f5184cf1ee Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 18:46:51 +1100
Subject: [PATCH 58/67] tests: test ObjectSerializerDisk class name extraction

---
 tests/test_object_serializer_disk.py | 29 +++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/tests/test_object_serializer_disk.py b/tests/test_object_serializer_disk.py
index 5ce1e579013..2bc7e16937e 100644
--- a/tests/test_object_serializer_disk.py
+++ b/tests/test_object_serializer_disk.py
@@ -113,28 +113,31 @@ def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with


 def test_obj_serializer_disk_different_types(tmp_path: Path):
-    obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
-
+    obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
     obj_1 = MockDataclass(foo="bar")
-    obj_1_name = obj_serializer.save(obj_1)
-    obj_1_loaded = obj_serializer.load(obj_1_name)
+    obj_1_name = obj_serializer_1.save(obj_1)
+    obj_1_loaded = obj_serializer_1.load(obj_1_name)

+    assert obj_serializer_1._obj_class_name == "MockDataclass"
     assert isinstance(obj_1_loaded, MockDataclass)
     assert obj_1_loaded.foo == "bar"
     assert obj_1_name.startswith("MockDataclass_")

-    obj_serializer = ObjectSerializerDisk[int](tmp_path)
-    obj_2_name = obj_serializer.save(9001)
-    assert obj_serializer.load(obj_2_name) == 9001
+    obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
+    obj_2_name = obj_serializer_2.save(9001)
+    assert obj_serializer_2._obj_class_name == "int"
+    assert obj_serializer_2.load(obj_2_name) == 9001
     assert obj_2_name.startswith("int_")

-    obj_serializer = ObjectSerializerDisk[str](tmp_path)
-    obj_3_name = obj_serializer.save("foo")
-    assert obj_serializer.load(obj_3_name) == "foo"
+    obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
+    obj_3_name = obj_serializer_3.save("foo")
+    assert obj_serializer_3._obj_class_name == "str"
+    assert obj_serializer_3.load(obj_3_name) == "foo"
     assert obj_3_name.startswith("str_")

-    obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path)
-    obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3]))
-    obj_4_loaded = obj_serializer.load(obj_4_name)
+    obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
+    obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
+    obj_4_loaded = obj_serializer_4.load(obj_4_name)
+    assert obj_serializer_4._obj_class_name == "Tensor"
     assert isinstance(obj_4_loaded, torch.Tensor)
     assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3]))
     assert obj_4_name.startswith("Tensor_")

From b7ffd36cc6d81f832ae922fe62d7a029f10115c1 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 19:11:28 +1100
Subject: [PATCH 59/67] feat(nodes): use TemporaryDirectory to handle ephemeral
 storage in ObjectSerializerDisk

Replace `delete_on_startup: bool` & associated logic with `ephemeral: bool`
and `TemporaryDirectory`.

The temp dir is created inside of `output_dir`. For example, if `output_dir`
is `invokeai/outputs/tensors/`, then the temp dir might be
`invokeai/outputs/tensors/tmpvj35ht7b/`.

The temp dir is cleaned up when the service is stopped, or when it is GC'd if
not properly stopped.

In the event of a catastrophic crash where the temp files are not cleaned up,
the user can delete the tempdir themselves. This situation may not occur in
normal use, but if you kill the process, Python cannot clean up the temp dir
itself. This includes running the app in a debugger and killing the debugger
process - something I do relatively often.

Tests updated.
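For reference, the pattern condensed from the diff below (class and method
names simplified here; this is a sketch, not the actual implementation):

    import tempfile
    from pathlib import Path

    class DiskStorage:
        def __init__(self, output_dir: Path, ephemeral: bool = False):
            self._base_output_dir = output_dir
            self._base_output_dir.mkdir(parents=True, exist_ok=True)
            # When ephemeral, writes go to a temp dir nested inside the base
            # dir, e.g. invokeai/outputs/tensors/tmpvj35ht7b/
            self._tempdir = (
                tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True)
                if ephemeral
                else None
            )
            self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir

        def __del__(self) -> None:
            # Fallback cleanup when GC'd, in case the service was never stopped.
            # TemporaryDirectory.cleanup() is safe to call more than once.
            if self._tempdir:
                self._tempdir.cleanup()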
--- invokeai/app/api/dependencies.py | 4 +- .../object_serializer_disk.py | 56 ++++++++----------- tests/test_object_serializer_disk.py | 53 +++++++----------- 3 files changed, 46 insertions(+), 67 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 2acb961aa7a..0f2a92b5c8e 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -91,10 +91,10 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) tensors = ObjectSerializerForwardCache( - ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True) + ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True) ) conditioning = ObjectSerializerForwardCache( - ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True) + ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py index 06f86aa460c..935fec30605 100644 --- a/invokeai/app/services/object_serializer/object_serializer_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -1,3 +1,4 @@ +import tempfile import typing from dataclasses import dataclass from pathlib import Path @@ -23,28 +24,24 @@ class DeleteAllResult: class ObjectSerializerDisk(ObjectSerializerBase[T]): - """Provides a disk-backed storage for arbitrary python objects. + """Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`. 
- :param output_folder: The folder where the objects will be stored - :param delete_on_startup: If True, all objects in the output folder will be deleted on startup + :param output_dir: The folder where the serialized objects will be stored + :param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit """ - def __init__(self, output_dir: Path, delete_on_startup: bool = False): + def __init__(self, output_dir: Path, ephemeral: bool = False): super().__init__() - self._output_dir = output_dir - self._output_dir.mkdir(parents=True, exist_ok=True) - self._delete_on_startup = delete_on_startup + self._ephemeral = ephemeral + self._base_output_dir = output_dir + self._base_output_dir.mkdir(parents=True, exist_ok=True) + # Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows + self._tempdir = ( + tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None + ) + self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir self.__obj_class_name: Optional[str] = None - def start(self, invoker: "Invoker") -> None: - if self._delete_on_startup: - delete_all_result = self._delete_all() - if delete_all_result.deleted_count > 0: - freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) - invoker.services.logger.info( - f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) - def load(self, name: str) -> T: file_path = self._get_path(name) try: @@ -75,19 +72,14 @@ def _get_path(self, name: str) -> Path: def _new_name(self) -> str: return f"{self._obj_class_name}_{uuid_string()}" - def _delete_all(self) -> DeleteAllResult: - """ - Deletes all objects from disk. - """ - - # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have - # to manually clear them on startup anyways. This is a bit simpler and more reliable. - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_dir).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - return DeleteAllResult(deleted_count, freed_space) + def _tempdir_cleanup(self) -> None: + """Calls `cleanup` on the temporary directory, if it exists.""" + if self._tempdir: + self._tempdir.cleanup() + + def __del__(self) -> None: + # In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd. 
+ self._tempdir_cleanup() + + def stop(self, invoker: "Invoker") -> None: + self._tempdir_cleanup() diff --git a/tests/test_object_serializer_disk.py b/tests/test_object_serializer_disk.py index 2bc7e16937e..125534c5002 100644 --- a/tests/test_object_serializer_disk.py +++ b/tests/test_object_serializer_disk.py @@ -1,12 +1,10 @@ +import tempfile from dataclasses import dataclass -from logging import Logger from pathlib import Path -from unittest.mock import Mock import pytest import torch -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -31,11 +29,6 @@ def fwd_cache(tmp_path: Path): return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) -@pytest.fixture -def mock_invoker_with_logger(): - return Mock(Invoker, services=Mock(logger=Mock(Logger))) - - def test_obj_serializer_disk_initializes(tmp_path: Path): obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._output_dir == tmp_path @@ -76,39 +69,33 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]): - obj_1 = MockDataclass(foo="bar") - obj_1_name = obj_serializer.save(obj_1) +def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory) + assert obj_serializer._base_output_dir == tmp_path + assert obj_serializer._output_dir != tmp_path + assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name) - obj_2 = MockDataclass(foo="bar") - obj_2_name = obj_serializer.save(obj_2) - delete_all_result = obj_serializer._delete_all() +def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + del obj_serializer + assert not tempdir_path.exists() - assert not Path(obj_serializer._output_dir, obj_1_name).exists() - assert not Path(obj_serializer._output_dir, obj_2_name).exists() - assert delete_all_result.deleted_count == 2 +def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + obj_serializer.stop(None) # pyright: ignore [reportArgumentType] + assert not tempdir_path.exists() -def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): - obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) - assert obj_serializer._delete_on_startup is False +def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) - - obj_serializer.start(mock_invoker_with_logger) - assert Path(tmp_path, obj_1_name).exists() - - -def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): - obj_serializer = 
ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True) - assert obj_serializer._delete_on_startup is True - - obj_1 = MockDataclass(foo="bar") - obj_1_name = obj_serializer.save(obj_1) - - obj_serializer.start(mock_invoker_with_logger) + assert Path(obj_serializer._output_dir, obj_1_name).exists() assert not Path(tmp_path, obj_1_name).exists() From ba7b1b266507ce712bbf8f6209f1a50d888bf090 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 08:52:07 +1100 Subject: [PATCH 60/67] feat(nodes): extract LATENT_SCALE_FACTOR to constants.py --- invokeai/app/invocations/constants.py | 7 +++++++ invokeai/app/invocations/latent.py | 7 +------ invokeai/backend/tiles/tiles.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) create mode 100644 invokeai/app/invocations/constants.py diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py new file mode 100644 index 00000000000..95b16f0d057 --- /dev/null +++ b/invokeai/app/invocations/constants.py @@ -0,0 +1,7 @@ +LATENT_SCALE_FACTOR = 8 +""" +HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to +be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale +factor is hard-coded to a literal '8' rather than using this constant. +The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. +""" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 4137ab6e2f6..fedfc38402d 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,6 +23,7 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( ConditioningField, DenoiseMaskField, @@ -79,12 +80,6 @@ SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] -# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to -# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale -# factor is hard-coded to a literal '8' rather than using this constant. -# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. 
-LATENT_SCALE_FACTOR = 8 - @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): diff --git a/invokeai/backend/tiles/tiles.py b/invokeai/backend/tiles/tiles.py index 3c400fc87ce..2757dadba20 100644 --- a/invokeai/backend/tiles/tiles.py +++ b/invokeai/backend/tiles/tiles.py @@ -3,7 +3,7 @@ import numpy as np -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend From 2005411f7e799e63a814995bea3cd32af5c77792 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 08:54:01 +1100 Subject: [PATCH 61/67] feat(nodes): use LATENT_SCALE_FACTOR in primitives.py, noise.py - LatentsOutput.build - NoiseOutput.build - Noise.width, Noise.height multiple_of --- invokeai/app/invocations/noise.py | 9 +++++---- invokeai/app/invocations/primitives.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 74b3d6e4cb1..335d3df292e 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,6 +4,7 @@ import torch from pydantic import field_validator +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -70,8 +71,8 @@ class NoiseOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": return cls( noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) @@ -93,13 +94,13 @@ class NoiseInvocation(BaseInvocation): ) width: int = InputField( default=512, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, gt=0, description=FieldDescriptions.width, ) height: int = InputField( default=512, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, gt=0, description=FieldDescriptions.height, ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index d0f95c92d02..43422134829 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -4,6 +4,7 @@ import torch +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( ColorField, ConditioningField, @@ -321,8 +322,8 @@ class LatentsOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": return cls( latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) From 083a4f3faa2e4c22a2d1949141ba2425c622c511 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:27:57 +1100 Subject: [PATCH 62/67] chore(backend): rename `ModelInfo` -> `LoadedModelInfo` We have two different classes named `ModelInfo` which might need to be used by API consumers. We need to export both but have to deal with this naming collision. 
The `ModelInfo` I've renamed here is the one that is returned when a model is loaded. It's the object least likely to be used by API consumers. --- invokeai/app/services/events/events_base.py | 10 +++++----- .../services/model_manager/model_manager_base.py | 4 ++-- .../model_manager/model_manager_default.py | 16 ++++++++-------- .../app/services/shared/invocation_context.py | 7 ++++--- invokeai/backend/__init__.py | 9 ++++++++- invokeai/backend/model_management/__init__.py | 2 +- .../backend/model_management/model_manager.py | 6 +++--- invokeai/backend/util/test_utils.py | 10 +++++----- invokeai/invocation_api/__init__.py | 4 ++-- 9 files changed, 38 insertions(+), 30 deletions(-) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index ad08ae03956..6b441efc2bf 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -11,7 +11,7 @@ SessionQueueStatus, ) from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType @@ -201,7 +201,7 @@ def emit_model_load_completed( base_model: BaseModelType, model_type: ModelType, submodel: SubModelType, - model_info: ModelInfo, + loaded_model_info: LoadedModelInfo, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_queue_event( @@ -215,9 +215,9 @@ def emit_model_load_completed( "base_model": base_model, "model_type": model_type, "submodel": submodel, - "hash": model_info.hash, - "location": str(model_info.location), - "precision": str(model_info.precision), + "hash": loaded_model_info.hash, + "location": str(loaded_model_info.location), + "precision": str(loaded_model_info.precision), }, ) diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index a9b53ae2242..f888c0ec973 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -14,8 +14,8 @@ from invokeai.backend.model_management import ( AddModelResult, BaseModelType, + LoadedModelInfo, MergeInterpolationMethod, - ModelInfo, ModelType, SchedulerPredictionType, SubModelType, @@ -48,7 +48,7 @@ def get_model( model_type: ModelType, submodel: Optional[SubModelType] = None, context_data: Optional[InvocationContextData] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """Retrieve the indicated model with name and type. submodel can be used to get a part (such as the vae) of a diffusers pipeline.""" diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index b641dd3f1ed..c3712abf8e6 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -16,8 +16,8 @@ from invokeai.backend.model_management import ( AddModelResult, BaseModelType, + LoadedModelInfo, MergeInterpolationMethod, - ModelInfo, ModelManager, ModelMerger, ModelNotFoundException, @@ -98,7 +98,7 @@ def get_model( model_type: ModelType, submodel: Optional[SubModelType] = None, context_data: Optional[InvocationContextData] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """ Retrieve the indicated model. 
submodel can be used to get a part (such as the vae) of a diffusers mode. @@ -114,7 +114,7 @@ def get_model( submodel=submodel, ) - model_info = self.mgr.get_model( + loaded_model_info = self.mgr.get_model( model_name, base_model, model_type, @@ -128,10 +128,10 @@ def get_model( base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info, + loaded_model_info=loaded_model_info, ) - return model_info + return loaded_model_info def model_exists( self, @@ -273,7 +273,7 @@ def _emit_load_event( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - model_info: Optional[ModelInfo] = None, + loaded_model_info: Optional[LoadedModelInfo] = None, ): if self._invoker is None: return @@ -281,7 +281,7 @@ def _emit_load_event( if self._invoker.services.queue.is_canceled(context_data.session_id): raise CanceledException() - if model_info: + if loaded_model_info: self._invoker.services.events.emit_model_load_completed( queue_id=context_data.queue_id, queue_item_id=context_data.queue_item_id, @@ -291,7 +291,7 @@ def _emit_load_event( base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info, + loaded_model_info=loaded_model_info, ) else: self._invoker.services.events.emit_model_load_started( diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 1ca44b78625..68fb78c1430 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -13,7 +13,7 @@ from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -272,14 +272,15 @@ def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelTy def load( self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None - ) -> ModelInfo: + ) -> LoadedModelInfo: """ - Loads a model, returning its `ModelInfo` object. + Loads a model. :param model_name: The name of the model to get. :param base_model: The base model of the model to get. :param model_type: The type of the model to get. :param submodel: The submodel of the model to get. + :returns: An object representing the loaded model. """ # The model manager emits events as it loads the model. 
It needs the context data to build diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index ae9a12edbe2..54a1843d463 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -1,5 +1,12 @@ """ Initialization file for invokeai.backend """ -from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401 +from .model_management import ( # noqa: F401 + BaseModelType, + LoadedModelInfo, + ModelCache, + ModelManager, + ModelType, + SubModelType, +) from .model_management.models import SilenceWarnings # noqa: F401 diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index 03abf58eb46..d523a7a0c8d 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -3,7 +3,7 @@ Initialization file for invokeai.backend.model_management """ # This import must be first -from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType +from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType from .lora import ModelPatcher, ONNXModelPatcher from .model_cache import ModelCache diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 362d8d3ff55..da74ca3fb58 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -271,7 +271,7 @@ class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are n @dataclass -class ModelInfo: +class LoadedModelInfo: context: ModelLocker name: str base_model: BaseModelType @@ -450,7 +450,7 @@ def get_model( base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """Given a model named identified in models.yaml, return an ModelInfo object describing it. :param model_name: symbolic name of the model in models.yaml @@ -508,7 +508,7 @@ def get_model( model_hash = "" # TODO: - return ModelInfo( + return LoadedModelInfo( context=model_context, name=model_name, base_model=base_model, diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index 09b9de9e984..685603cedc6 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -7,7 +7,7 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.backend.install.model_install_backend import ModelInstall -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType @@ -34,8 +34,8 @@ def install_and_load_model( base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, -) -> ModelInfo: - """Install a model if it is not already installed, then get the ModelInfo for that model. +) -> LoadedModelInfo: + """Install a model if it is not already installed, then get the LoadedModelInfo for that model. This is intended as a utility function for tests. @@ -49,9 +49,9 @@ def install_and_load_model( submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...). Returns: - ModelInfo + LoadedModelInfo """ - # If the requested model is already installed, return its ModelInfo. 
+ # If the requested model is already installed, return its LoadedModelInfo. with contextlib.suppress(ModelNotFoundException): return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index e80bc26a003..2d3ceca11e2 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -52,7 +52,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( @@ -121,7 +121,7 @@ # invokeai.app.services.config.config_default "InvokeAIAppConfig", # invokeai.backend.model_management.model_manager - "ModelInfo", + "LoadedModelInfo", # invokeai.backend.model_management.models.base "BaseModelType", "ModelType", From 8fb77e431e8c5c1e8863ee7d76388f061592eca4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:39:36 +1100 Subject: [PATCH 63/67] chore(nodes): export model-related objects from invocation_api --- invokeai/invocation_api/__init__.py | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 2d3ceca11e2..055dd12757d 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -28,6 +28,22 @@ WithMetadata, WithWorkflow, ) +from invokeai.app.invocations.model import ( + ClipField, + CLIPOutput, + LoraInfo, + LoraLoaderOutput, + LoRAModelField, + MainModelField, + ModelInfo, + ModelLoaderOutput, + SDXLLoraLoaderOutput, + UNetField, + UNetOutput, + VaeField, + VAEModelField, + VAEOutput, +) from invokeai.app.invocations.primitives import ( BooleanCollectionOutput, BooleanOutput, @@ -87,6 +103,21 @@ "UIType", "WithMetadata", "WithWorkflow", + # invokeai.app.invocations.model + "ModelInfo", + "LoraInfo", + "UNetField", + "ClipField", + "VaeField", + "MainModelField", + "LoRAModelField", + "VAEModelField", + "UNetOutput", + "VAEOutput", + "CLIPOutput", + "ModelLoaderOutput", + "LoraLoaderOutput", + "SDXLLoraLoaderOutput", # invokeai.app.invocations.primitives "BooleanCollectionOutput", "BooleanOutput", From 321b939d0efc66729b3e7ec8948e50798a83fa42 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:43:36 +1100 Subject: [PATCH 64/67] chore(nodes): remove deprecation logic for nodes API --- .../app/services/shared/invocation_context.py | 112 ------------------ 1 file changed, 112 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 68fb78c1430..c68dc1140b2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional -from deprecated import deprecated from 
PIL.Image import Image from torch import Tensor @@ -334,30 +333,6 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m ) -deprecation_version = "3.7.0" -removed_version = "3.8.0" - - -def get_deprecation_reason(property_name: str, alternative: Optional[str] = None) -> str: - msg = f"{property_name} is deprecated as of v{deprecation_version}. It will be removed in v{removed_version}." - if alternative is not None: - msg += f" Use {alternative} instead." - msg += " See PLACEHOLDER_URL for details." - return msg - - -# Deprecation docstrings template. I don't think we can implement these programmatically with -# __doc__ because the IDE won't see them. - -""" -**DEPRECATED as of v3.7.0** - -PROPERTY_NAME will be removed in v3.8.0. Use ALTERNATIVE instead. See PLACEHOLDER_URL for details. - -OG_DOCSTRING -""" - - class InvocationContext: """ The `InvocationContext` provides access to various services and data for the current invocation. @@ -397,93 +372,6 @@ def __init__( self._services = services """Provides access to the full application services. This is an internal API and may change without warning.""" - @property - @deprecated(version=deprecation_version, reason=get_deprecation_reason("`context.services`")) - def services(self) -> InvocationServices: - """ - **DEPRECATED as of v3.7.0** - - `context.services` will be removed in v3.8.0. See PLACEHOLDER_URL for details. - - The invocation services. - """ - return self._services - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.graph_execution_state_id", "`context._data.session_id`"), - ) - def graph_execution_state_id(self) -> str: - """ - **DEPRECATED as of v3.7.0** - - `context.graph_execution_state_api` will be removed in v3.8.0. Use `context._data.session_id` instead. See PLACEHOLDER_URL for details. - - The ID of the session (aka graph execution state). - """ - return self._data.session_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.queue_id`", "`context._data.queue_id`"), - ) - def queue_id(self) -> str: - """ - **DEPRECATED as of v3.7.0** - - `context.queue_id` will be removed in v3.8.0. Use `context._data.queue_id` instead. See PLACEHOLDER_URL for details. - - The ID of the queue. - """ - return self._data.queue_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.queue_item_id`", "`context._data.queue_item_id`"), - ) - def queue_item_id(self) -> int: - """ - **DEPRECATED as of v3.7.0** - - `context.queue_item_id` will be removed in v3.8.0. Use `context._data.queue_item_id` instead. See PLACEHOLDER_URL for details. - - The ID of the queue item. - """ - return self._data.queue_item_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.queue_batch_id`", "`context._data.batch_id`"), - ) - def queue_batch_id(self) -> str: - """ - **DEPRECATED as of v3.7.0** - - `context.queue_batch_id` will be removed in v3.8.0. Use `context._data.batch_id` instead. See PLACEHOLDER_URL for details. - - The ID of the batch. - """ - return self._data.batch_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.workflow`", "`context._data.workflow`"), - ) - def workflow(self) -> Optional[WorkflowWithoutID]: - """ - **DEPRECATED as of v3.7.0** - - `context.workflow` will be removed in v3.8.0. Use `context._data.workflow` instead. See PLACEHOLDER_URL for details. 
- - The workflow associated with this queue item, if any. - """ - return self._data.workflow - def build_invocation_context( services: InvocationServices, From 1bbd13ead7dadef26db9b54febe7c19215723fcf Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:51:25 +1100 Subject: [PATCH 65/67] chore(nodes): "SAMPLER_NAME_VALUES" -> "SCHEDULER_NAME_VALUES" This was named inaccurately. --- invokeai/app/invocations/constants.py | 7 +++++++ invokeai/app/invocations/latent.py | 10 ++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py index 95b16f0d057..795e7a3b604 100644 --- a/invokeai/app/invocations/constants.py +++ b/invokeai/app/invocations/constants.py @@ -1,3 +1,7 @@ +from typing import Literal + +from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP + LATENT_SCALE_FACTOR = 8 """ HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to @@ -5,3 +9,6 @@ factor is hard-coded to a literal '8' rather than using this constant. The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. """ + +SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +"""A literal type representing the valid scheduler names.""" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index fedfc38402d..69e3f055ca8 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,7 +23,7 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize -from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( ConditioningField, DenoiseMaskField, @@ -78,12 +78,10 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) -SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] - @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): - scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) + scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) @invocation( @@ -96,7 +94,7 @@ class SchedulerOutput(BaseInvocationOutput): class SchedulerInvocation(BaseInvocation): """Selects a scheduler.""" - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, @@ -234,7 +232,7 @@ class DenoiseLatentsInvocation(BaseInvocation): description=FieldDescriptions.denoising_start, ) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, From 6c4eeaa5692466df322d300bed8fa4d97c8f5e88 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 10:06:53 +1100 Subject: [PATCH 66/67] feat(nodes): add more missing exports to invocation_api Crawled through a few custom nodes to figure out what I had missed. 
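For illustration, a hypothetical custom node that uses several of the
newly-exported names (the node itself is made up, and it assumes `invocation`,
`BaseInvocation`, `InputField` and `InvocationContext` are exported alongside
the names added here):

    from invokeai.invocation_api import (
        BaseInvocation,
        InputField,
        InvocationContext,
        SCHEDULER_NAME_VALUES,
        SchedulerOutput,
        invocation,
    )

    @invocation("scheduler_passthrough", title="Scheduler Passthrough", version="1.0.0")
    class SchedulerPassthroughInvocation(BaseInvocation):
        """Hypothetical custom node: passes a scheduler selection through unchanged."""

        scheduler: SCHEDULER_NAME_VALUES = InputField(default="euler", description="The scheduler to emit")

        def invoke(self, context: InvocationContext) -> SchedulerOutput:
            return SchedulerOutput(scheduler=self.scheduler)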
--- invokeai/invocation_api/__init__.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 055dd12757d..e110b5a2db3 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -7,9 +7,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Classification, invocation, invocation_output, ) +from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( BoardField, ColorField, @@ -28,6 +30,8 @@ WithMetadata, WithWorkflow, ) +from invokeai.app.invocations.latent import SchedulerOutput +from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput from invokeai.app.invocations.model import ( ClipField, CLIPOutput, @@ -68,6 +72,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -77,11 +82,14 @@ ExtraConditioningInfo, SDXLConditioningInfo, ) +from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device +from invokeai.version import __version__ __all__ = [ # invokeai.app.invocations.baseinvocation "BaseInvocation", "BaseInvocationOutput", + "Classification", "invocation", "invocation_output", # invokeai.app.services.shared.invocation_context @@ -103,6 +111,12 @@ "UIType", "WithMetadata", "WithWorkflow", + # invokeai.app.invocations.latent + "SchedulerOutput", + # invokeai.app.invocations.metadata + "MetadataItemField", + "MetadataItemOutput", + "MetadataOutput", # invokeai.app.invocations.model "ModelInfo", "LoraInfo", @@ -157,4 +171,17 @@ "BaseModelType", "ModelType", "SubModelType", + # invokeai.app.invocations.constants + "SCHEDULER_NAME_VALUES", + # invokeai.version + "__version__", + # invokeai.backend.util.devices + "choose_precision", + "choose_torch_device", + "CPU_DEVICE", + "CUDA_DEVICE", + "MPS_DEVICE", + # invokeai.app.util.misc + "SEED_MAX", + "get_random_seed", ] From c76a6bd65f75a032a458220c55310b0592dfc90d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:02:30 +1100 Subject: [PATCH 67/67] chore(ui): regen types --- .../frontend/web/src/services/api/schema.ts | 1892 +---------------- 1 file changed, 67 insertions(+), 1825 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 45358ed97d5..1599b310c9a 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -680,70 +680,6 @@ export type components = { */ type: "add"; }; - /** - * Adjust Image Hue Plus - * @description Adjusts the Hue of an image by rotating it in the selected color space - */ - AdjustImageHuePlusInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * 
@description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to adjust */ - image?: components["schemas"]["ImageField"]; - /** - * Space - * @description Color space in which to rotate hue by polar coords (*: non-invertible) - * @default Okhsl - * @enum {string} - */ - space?: "HSV / HSL / RGB" | "Okhsl" | "Okhsv" | "*Oklch / Oklab" | "*LCh / CIELab" | "*UPLab (w/CIELab_to_UPLab.icc)"; - /** - * Degrees - * @description Degrees by which to rotate image hue - * @default 0 - */ - degrees?: number; - /** - * Preserve Lightness - * @description Whether to preserve CIELAB lightness values - * @default false - */ - preserve_lightness?: boolean; - /** - * Ok Adaptive Gamut - * @description Higher preserves chroma at the expense of lightness (Oklab) - * @default 0.05 - */ - ok_adaptive_gamut?: number; - /** - * Ok High Precision - * @description Use more steps in computing gamut (Oklab/Okhsv/Okhsl) - * @default true - */ - ok_high_precision?: boolean; - /** - * type - * @default img_hue_adjust_plus - * @constant - */ - type: "img_hue_adjust_plus"; - }; /** * AppConfig * @description App Config Response @@ -1454,39 +1390,6 @@ export type components = { */ type: "boolean_output"; }; - /** - * BRIA AI Background Removal - * @description Uses the new Bria 1.4 model to remove backgrounds from images. - */ - BriaRemoveBackgroundInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to crop */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default bria_bg_remove - * @constant - */ - type: "bria_bg_remove"; - }; /** * CLIPOutput * @description Base class for invocations that output a CLIP field @@ -1581,282 +1484,6 @@ export type components = { /** @description Base model (usually 'Any') */ base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; }; - /** - * CMYK Color Separation - * @description Get color images from a base color and two others that subtractively mix to obtain it - */ - CMYKColorSeparationInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * C Value - * @description Desired final cyan value - * @default 0 - */ - c_value?: number; - /** - * M Value - * @description Desired final magenta value - * @default 25 - */ - m_value?: number; - /** - * Y Value - * @description Desired final yellow value - * @default 28 - */ - y_value?: number; - /** - * K Value - * @description Desired final black value - * @default 76 - */ - k_value?: number; - /** - * C Split - * @description Desired cyan split point % [0..1.0] - * @default 0.5 - */ - c_split?: number; - /** - * M Split - * @description Desired magenta split point % [0..1.0] - * @default 1 - */ - m_split?: number; - /** - * Y Split - * @description Desired yellow split point % [0..1.0] - * @default 0 - */ - y_split?: number; - /** - * K Split - * @description Desired black split point % [0..1.0] - * @default 0.5 - */ - k_split?: number; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_separation - * @constant - */ - type: "cmyk_separation"; - }; - /** - * CMYK Merge - * @description Merge subtractive color channels (CMYK+alpha) - */ - CMYKMergeInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The c channel */ - c_channel?: components["schemas"]["ImageField"] | null; - /** @description The m channel */ - m_channel?: components["schemas"]["ImageField"] | null; - /** @description The y channel */ - y_channel?: components["schemas"]["ImageField"] | null; - /** @description The k channel */ - k_channel?: components["schemas"]["ImageField"] | null; - /** @description The alpha channel */ - alpha_channel?: components["schemas"]["ImageField"] | null; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_merge - * @constant - */ - type: "cmyk_merge"; - }; - /** - * CMYKSeparationOutput - * @description Base class for invocations that output four L-mode images (C, M, Y, K) - */ - CMYKSeparationOutput: { - /** @description Blank image of the specified color */ - color_image: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the image in pixels - */ - width: number; - /** - * Height - * @description The height of the image in pixels - */ - height: number; - /** @description Blank image of the first separated color */ - part_a: components["schemas"]["ImageField"]; - /** - * Rgb Red A - * @description R value of color part A - */ - rgb_red_a: number; - /** - * Rgb Green A - * @description G value of color part A - */ - rgb_green_a: number; - /** - * Rgb Blue A - * @description B value of color part A - */ - rgb_blue_a: number; - /** @description Blank image of the second separated color */ - part_b: components["schemas"]["ImageField"]; - /** - * Rgb Red B - * @description R value of color part B - */ - rgb_red_b: number; - /** - * Rgb Green B - * @description G value of color part B - */ - rgb_green_b: number; - /** - * Rgb Blue B - * @description B value of color part B - */ - rgb_blue_b: number; - /** - * type - * @default cmyk_separation_output - * @constant - */ - type: "cmyk_separation_output"; - }; - /** - * CMYK Split - * @description Split an image into subtractive color channels (CMYK+alpha) - */ - CMYKSplitInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to halftone */ - image?: components["schemas"]["ImageField"]; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_split - * @constant - */ - type: "cmyk_split"; - }; - /** - * CMYKSplitOutput - * @description Base class for invocations that output four L-mode images (C, M, Y, K) - */ - CMYKSplitOutput: { - /** @description Grayscale image of the cyan channel */ - c_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the magenta channel */ - m_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the yellow channel */ - y_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the k channel */ - k_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the alpha channel */ - alpha_channel: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the image in pixels - */ - width: number; - /** - * Height - * @description The height of the image in pixels - */ - height: number; - /** - * type - * @default cmyk_split_output - * @constant - */ - type: "cmyk_split_output"; - }; /** * CV2 Infill * @description Infills transparent areas of an image using OpenCV Inpainting @@ -3491,6 +3118,8 @@ export type components = { * @description Generates an openpose pose from an image using DWPose */ DWOpenposeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4014,11 +3643,20 @@ export type components = { */ priority: number; }; + /** ExposedField */ + ExposedField: { + /** Nodeid */ + nodeId: string; + /** Fieldname */ + fieldName: string; + }; /** - * Equivalent Achromatic Lightness - * @description Calculate Equivalent Achromatic Lightness from image + * FaceIdentifier + * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. */ - EquivalentAchromaticLightnessInvocation: { + FaceIdentifierInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4038,54 +3676,12 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description Image from which to get channel */ + /** @description Image to face detect */ image?: components["schemas"]["ImageField"]; /** - * type - * @default ealightness - * @constant - */ - type: "ealightness"; - }; - /** ExposedField */ - ExposedField: { - /** Nodeid */ - nodeId: string; - /** Fieldname */ - fieldName: string; - }; - /** - * FaceIdentifier - * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. 
- */ - FaceIdentifierInvocation: { - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image to face detect */ - image?: components["schemas"]["ImageField"]; - /** - * Minimum Confidence - * @description Minimum confidence for face detection (lower if detection is failing) - * @default 0.5 + * Minimum Confidence + * @description Minimum confidence for face detection (lower if detection is failing) + * @default 0.5 */ minimum_confidence?: number; /** @@ -4301,39 +3897,6 @@ export type components = { */ y: number; }; - /** - * Flatten Histogram (Grayscale) - * @description Scales the values of an L-mode image by scaling them to the full range 0..255 in equal proportions - */ - FlattenHistogramMono: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Single-channel image for which to flatten the histogram */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default flatten_histogram_mono - * @constant - */ - type: "flatten_histogram_mono"; - }; /** * Float Collection Primitive * @description A collection of float primitive values @@ -4683,7 +4246,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["ImageDilateOrErodeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["HandDepthMeshGraphormerProcessor"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["TextToMaskClipsegInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["OffsetLatentsInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["CMYKColorSeparationInvocation"] | components["schemas"]["NoiseImage2DInvocation"] | 
components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["EquivalentAchromaticLightnessInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageBlendInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["ImageRotateInvocation"] | components["schemas"]["ShadowsHighlightsMidtonesMaskInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["BriaRemoveBackgroundInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImageValueThresholdsInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["NoiseSpectralInvocation"] | components["schemas"]["TextMaskInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["MaskedBlendLatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CMYKMergeInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["TextToMaskClipsegAdvancedInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageCompositorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["FlattenHistogramMono"] | components["schemas"]["AdjustImageHuePlusInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["CMYKSplitInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | 
components["schemas"]["FloatInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageEnhanceInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["LatentConsistencyInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageOffsetInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"]; + [key: string]: components["schemas"]["ImageCropInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["AddInvocation"] | 
components["schemas"]["NoiseInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["CvInpaintInvocation"] | 
components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["IdealSizeInvocation"]; }; /** * Edges @@ -4720,7 +4283,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["LoraLoaderOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["CMYKSplitOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["CollectInvocationOutput"] | 
components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["CMYKSeparationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["HandDepthOutput"] | components["schemas"]["ShadowsHighlightsMidtonesMasksOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ColorCollectionOutput"]; + [key: string]: components["schemas"]["MetadataOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ControlOutput"] | 
components["schemas"]["IPAdapterOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ClipSkipInvocationOutput"]; }; /** * Errors @@ -4811,83 +4374,6 @@ export type components = { /** Detail */ detail?: components["schemas"]["ValidationError"][]; }; - /** - * Hand Depth w/ MeshGraphormer - * @description Generate hand depth maps to inpaint with using ControlNet - */ - HandDepthMeshGraphormerProcessor: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to process */ - image?: components["schemas"]["ImageField"]; - /** - * Resolution - * @description Pixel resolution for output image - * @default 512 - */ - resolution?: number; - /** - * Mask Padding - * @description Amount to pad the hand mask by - * @default 30 - */ - mask_padding?: number; - /** - * Offload - * @description Offload model after usage - * @default false - */ - offload?: boolean; - /** - * type - * @default hand_depth_mesh_graphormer_image_processor - * @constant - */ - type: "hand_depth_mesh_graphormer_image_processor"; - }; - /** - * HandDepthOutput - * @description Base class for to output Meshgraphormer results - */ - HandDepthOutput: { - /** @description Improved hands depth map */ - image: components["schemas"]["ImageField"]; - /** @description Hands area mask */ - mask: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the depth map in pixels - */ - width: number; - /** - * Height - * @description The height of the depth map in pixels - */ - height: number; - /** - * type - * @default meshgraphormer_output - * @constant - */ - type: "meshgraphormer_output"; - }; /** * HED (softedge) Processor * @description Applies HED edge detection to image @@ -5260,87 +4746,6 @@ export type components = { */ type: "ideal_size_output"; }; - /** - * Image Layer Blend - * @description Blend two images together, with optional opacity, mask, and blend modes - */ - ImageBlendInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The top image to blend */ - layer_upper?: components["schemas"]["ImageField"]; - /** - * Blend Mode - * @description Available blend modes - * @default Normal - * @enum {string} - */ - blend_mode?: "Normal" | "Lighten Only" | "Darken Only" | "Lighten Only (EAL)" | "Darken Only (EAL)" | "Hue" | "Saturation" | "Color" | "Luminosity" | "Linear Dodge (Add)" | "Subtract" | "Multiply" | "Divide" | "Screen" | "Overlay" | "Linear Burn" | "Difference" | "Hard Light" | "Soft Light" | "Vivid Light" | "Linear Light" | "Color Burn" | "Color Dodge"; - /** - * Opacity - * @description Desired opacity of the upper layer - * @default 1 - */ - opacity?: number; - /** @description Optional mask, used to restrict areas from blending */ - mask?: components["schemas"]["ImageField"] | null; - /** - * Fit To Width - * @description Scale upper layer to fit base width - * @default false - */ - fit_to_width?: boolean; - /** - * Fit To Height - * @description Scale upper layer to fit base height - * @default true - */ - fit_to_height?: boolean; - /** @description The bottom image to blend */ - layer_base?: components["schemas"]["ImageField"]; - /** - * Color Space - * @description Available color spaces for blend computations - * @default Linear RGB - * @enum {string} - */ - color_space?: "RGB" | "Linear RGB" | "HSL (RGB)" | "HSV (RGB)" | "Okhsl" | "Okhsv" | "Oklch (Oklab)" | "LCh (CIELab)"; - /** - * Adaptive Gamut - * @description Adaptive gamut clipping (0=off). Higher prioritizes chroma over lightness - * @default 0 - */ - adaptive_gamut?: number; - /** - * High Precision - * @description Use more steps in computing gamut when possible - * @default true - */ - high_precision?: boolean; - /** - * type - * @default img_blend - * @constant - */ - type: "img_blend"; - }; /** * Blur Image * @description Blurs an image @@ -5594,77 +4999,6 @@ export type components = { */ type: "image_collection_output"; }; - /** - * Image Compositor - * @description Removes backdrop from subject image then overlays subject on background image - */ - ImageCompositorInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image of the subject on a plain monochrome background */ - image_subject?: components["schemas"]["ImageField"]; - /** @description Image of a background scene */ - image_background?: components["schemas"]["ImageField"]; - /** - * Chroma Key - * @description Can be empty for corner flood select, or CSS-3 color or tuple - * @default - */ - chroma_key?: string; - /** - * Threshold - * @description Subject isolation flood-fill threshold - * @default 50 - */ - threshold?: number; - /** - * Fill X - * @description Scale base subject image to fit background width - * @default false - */ - fill_x?: boolean; - /** - * Fill Y - * @description Scale base subject image to fit background height - * @default true - */ - fill_y?: boolean; - /** - * X Offset - * @description x-offset for the subject - * @default 0 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for the subject - * @default 0 - */ - y_offset?: number; - /** - * type - * @default img_composite - * @constant - */ - type: "img_composite"; - }; /** * Convert Image Mode * @description Converts an image to a different mode. @@ -5847,10 +5181,23 @@ export type components = { board_id?: string | null; }; /** - * Image Dilate or Erode - * @description Dilate (expand) or erode (contract) an image + * ImageField + * @description An image primitive field + */ + ImageField: { + /** + * Image Name + * @description The name of the image + */ + image_name: string; + }; + /** + * Adjust Image Hue + * @description Adjusts the Hue of an image. */ - ImageDilateOrErodeInvocation: { + ImageHueAdjustmentInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5870,158 +5217,24 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description The image from which to create a mask */ + /** @description The image to adjust */ image?: components["schemas"]["ImageField"]; /** - * Lightness Only - * @description If true, only applies to image lightness (CIELa*b*) - * @default false - */ - lightness_only?: boolean; - /** - * Radius W - * @description Width (in pixels) by which to dilate(expand) or erode (contract) the image - * @default 4 - */ - radius_w?: number; - /** - * Radius H - * @description Height (in pixels) by which to dilate(expand) or erode (contract) the image - * @default 4 - */ - radius_h?: number; - /** - * Mode - * @description How to operate on the image - * @default Dilate - * @enum {string} + * Hue + * @description The degrees by which to rotate the hue, 0-360 + * @default 0 */ - mode?: "Dilate" | "Erode"; + hue?: number; /** * type - * @default img_dilate_erode + * @default img_hue_adjust * @constant */ - type: "img_dilate_erode"; + type: "img_hue_adjust"; }; /** - * Enhance Image - * @description Applies processing from PIL's ImageEnhance module. - */ - ImageEnhanceInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. 
- */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image for which to apply processing */ - image?: components["schemas"]["ImageField"]; - /** - * Invert - * @description Whether to invert the image colors - * @default false - */ - invert?: boolean; - /** - * Color - * @description Color enhancement factor - * @default 1 - */ - color?: number; - /** - * Contrast - * @description Contrast enhancement factor - * @default 1 - */ - contrast?: number; - /** - * Brightness - * @description Brightness enhancement factor - * @default 1 - */ - brightness?: number; - /** - * Sharpness - * @description Sharpness enhancement factor - * @default 1 - */ - sharpness?: number; - /** - * type - * @default img_enhance - * @constant - */ - type: "img_enhance"; - }; - /** - * ImageField - * @description An image primitive field - */ - ImageField: { - /** - * Image Name - * @description The name of the image - */ - image_name: string; - }; - /** - * Adjust Image Hue - * @description Adjusts the Hue of an image. - */ - ImageHueAdjustmentInvocation: { - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to adjust */ - image?: components["schemas"]["ImageField"]; - /** - * Hue - * @description The degrees by which to rotate the hue, 0-360 - * @default 0 - */ - hue?: number; - /** - * type - * @default img_hue_adjust - * @constant - */ - type: "img_hue_adjust"; - }; - /** - * Inverse Lerp Image - * @description Inverse linear interpolation of all pixels of an image + * Inverse Lerp Image + * @description Inverse linear interpolation of all pixels of an image */ ImageInverseLerpInvocation: { /** @description The board to save the image to */ @@ -6216,57 +5429,6 @@ export type components = { */ type: "img_nsfw"; }; - /** - * Offset Image - * @description Offsets an image by a given percentage (or pixel amount). - */ - ImageOffsetInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * As Pixels - * @description Interpret offsets as pixels rather than percentages - * @default false - */ - as_pixels?: boolean; - /** @description Image to be offset */ - image?: components["schemas"]["ImageField"]; - /** - * X Offset - * @description x-offset for the subject - * @default 0.5 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for the subject - * @default 0.5 - */ - y_offset?: number; - /** - * type - * @default offset_image - * @constant - */ - type: "offset_image"; - }; /** * ImageOutput * @description Base class for nodes that output a single image @@ -6432,63 +5594,6 @@ export type components = { */ type: "img_resize"; }; - /** - * Rotate/Flip Image - * @description Rotates an image by a given angle (in degrees clockwise). - */ - ImageRotateInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image to be rotated clockwise */ - image?: components["schemas"]["ImageField"]; - /** - * Degrees - * @description Angle (in degrees clockwise) by which to rotate - * @default 90 - */ - degrees?: number; - /** - * Expand To Fit - * @description If true, extends the image boundary to fit the rotated content - * @default true - */ - expand_to_fit?: boolean; - /** - * Flip Horizontal - * @description If true, flips the image horizontally - * @default false - */ - flip_horizontal?: boolean; - /** - * Flip Vertical - * @description If true, flips the image vertically - * @default false - */ - flip_vertical?: boolean; - /** - * type - * @default rotate_image - * @constant - */ - type: "rotate_image"; - }; /** * Scale Image * @description Scales an image by a factor @@ -6603,69 +5708,6 @@ export type components = { */ thumbnail_url: string; }; - /** - * Image Value Thresholds - * @description Clip image to pure black/white past specified thresholds - */ - ImageValueThresholdsInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Make light areas dark and vice versa - * @default false - */ - invert_output?: boolean; - /** - * Renormalize Values - * @description Rescale remaining values from minimum to maximum - * @default false - */ - renormalize_values?: boolean; - /** - * Lightness Only - * @description If true, only applies to image lightness (CIELa*b*) - * @default false - */ - lightness_only?: boolean; - /** - * Threshold Upper - * @description Threshold above which will be set to full value - * @default 0.5 - */ - threshold_upper?: number; - /** - * Threshold Lower - * @description Threshold below which will be set to minimum value - * @default 0.5 - */ - threshold_lower?: number; - /** - * type - * @default img_val_thresholds - * @constant - */ - type: "img_val_thresholds"; - }; /** * Add Invisible Watermark * @description Add an invisible watermark to an image @@ -8025,47 +7067,6 @@ export type components = { */ type: "tomask"; }; - /** - * Blend Latents/Noise (Masked) - * @description Blend two latents using a given alpha and mask. Latents must have same size. - */ - MaskedBlendLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Latents tensor */ - latents_a?: components["schemas"]["LatentsField"]; - /** @description Latents tensor */ - latents_b?: components["schemas"]["LatentsField"]; - /** @description Mask for blending in latents B */ - mask?: components["schemas"]["ImageField"]; - /** - * Alpha - * @description Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B. - * @default 0.5 - */ - alpha?: number; - /** - * type - * @default lmblend - * @constant - */ - type: "lmblend"; - }; /** * Mediapipe Face Processor * @description Applies mediapipe face processing to image @@ -8679,88 +7680,8 @@ export type components = { value: string | number; }; /** - * 2D Noise Image - * @description Creates an image of 2D Noise approximating the desired characteristics - */ - NoiseImage2DInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Noise Type - * @description Desired noise spectral characteristics - * @default White - * @enum {string} - */ - noise_type?: "White" | "Red" | "Blue" | "Green"; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * Seed - * @description Seed for noise generation - * @default 0 - */ - seed?: number; - /** - * Iterations - * @description Noise approx. iterations - * @default 15 - */ - iterations?: number; - /** - * Blur Threshold - * @description Threshold used in computing noise (lower is better/slower) - * @default 0.2 - */ - blur_threshold?: number; - /** - * Sigma Red - * @description Sigma for strong gaussian blur LPF for red/green - * @default 3 - */ - sigma_red?: number; - /** - * Sigma Blue - * @description Sigma for weak gaussian blur HPF for blue/green - * @default 1 - */ - sigma_blue?: number; - /** - * type - * @default noiseimg_2d - * @constant - */ - type: "noiseimg_2d"; - }; - /** - * Noise - * @description Generates latent noise. + * Noise + * @description Generates latent noise. */ NoiseInvocation: { /** @@ -8835,84 +7756,6 @@ export type components = { */ type: "noise_output"; }; - /** - * Noise (Spectral characteristics) - * @description Creates an image of 2D Noise approximating the desired characteristics - */ - NoiseSpectralInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Noise Type - * @description Desired noise spectral characteristics - * @default White - * @enum {string} - */ - noise_type?: "White" | "Red" | "Blue" | "Green"; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * Seed - * @description Seed for noise generation - * @default 0 - */ - seed?: number; - /** - * Iterations - * @description Noise approx. iterations - * @default 15 - */ - iterations?: number; - /** - * Blur Threshold - * @description Threshold used in computing noise (lower is better/slower) - * @default 0.2 - */ - blur_threshold?: number; - /** - * Sigma Red - * @description Sigma for strong gaussian blur LPF for red/green - * @default 3 - */ - sigma_red?: number; - /** - * Sigma Blue - * @description Sigma for weak gaussian blur HPF for blue/green - * @default 1 - */ - sigma_blue?: number; - /** - * type - * @default noise_spectral - * @constant - */ - type: "noise_spectral"; - }; /** * Normal BAE Processor * @description Applies NormalBae processing to image @@ -8960,107 +7803,6 @@ export type components = { */ type: "normalbae_image_processor"; }; - /** - * ONNX Latents to Image - * @description Generates an image from latents. 
- */ - ONNXLatentsToImageInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Denoised latents tensor */ - latents?: components["schemas"]["LatentsField"]; - /** @description VAE */ - vae?: components["schemas"]["VaeField"]; - /** - * type - * @default l2i_onnx - * @constant - */ - type: "l2i_onnx"; - }; - /** - * ONNXModelLoaderOutput - * @description Model loader output - */ - ONNXModelLoaderOutput: { - /** - * UNet - * @description UNet (scheduler, LoRAs) - */ - unet?: components["schemas"]["UNetField"]; - /** - * CLIP - * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count - */ - clip?: components["schemas"]["ClipField"]; - /** - * VAE Decoder - * @description VAE - */ - vae_decoder?: components["schemas"]["VaeField"]; - /** - * VAE Encoder - * @description VAE - */ - vae_encoder?: components["schemas"]["VaeField"]; - /** - * type - * @default model_loader_output_onnx - * @constant - */ - type: "model_loader_output_onnx"; - }; - /** ONNX Prompt (Raw) */ - ONNXPromptInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Prompt - * @description Raw prompt text (no parsing) - * @default - */ - prompt?: string; - /** @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count */ - clip?: components["schemas"]["ClipField"]; - /** - * type - * @default prompt_onnx - * @constant - */ - type: "prompt_onnx"; - }; /** * ONNXSD1Config * @description Model config for ONNX format models based on sd-1. @@ -9242,117 +7984,6 @@ export type components = { /** Upcast Attention */ upcast_attention: boolean; }; - /** - * ONNX Text to Latents - * @description Generates latents from conditionings. - */ - ONNXTextToLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Positive conditioning tensor */ - positive_conditioning?: components["schemas"]["ConditioningField"]; - /** @description Negative conditioning tensor */ - negative_conditioning?: components["schemas"]["ConditioningField"]; - /** @description Noise tensor */ - noise?: components["schemas"]["LatentsField"]; - /** - * Steps - * @description Number of steps to run - * @default 10 - */ - steps?: number; - /** - * Cfg Scale - * @description Classifier-Free Guidance scale - * @default 7.5 - */ - cfg_scale?: number | number[]; - /** - * Scheduler - * @description Scheduler to use during inference - * @default euler - * @enum {string} - */ - scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm"; - /** - * Precision - * @description Precision to use - * @default tensor(float16) - * @enum {string} - */ - precision?: "tensor(bool)" | "tensor(int8)" | "tensor(uint8)" | "tensor(int16)" | "tensor(uint16)" | "tensor(int32)" | "tensor(uint32)" | "tensor(int64)" | "tensor(uint64)" | "tensor(float16)" | "tensor(float)" | "tensor(double)"; - /** @description UNet (scheduler, LoRAs) */ - unet?: components["schemas"]["UNetField"]; - /** - * Control - * @description ControlNet(s) to apply - */ - control?: components["schemas"]["ControlField"] | components["schemas"]["ControlField"][]; - /** - * type - * @default t2l_onnx - * @constant - */ - type: "t2l_onnx"; - }; - /** - * Offset Latents - * @description Offsets a latents tensor by a given percentage of height/width. - */ - OffsetLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Latents tensor */ - latents?: components["schemas"]["LatentsField"]; - /** - * X Offset - * @description Approx percentage to offset (H) - * @default 0.5 - */ - x_offset?: number; - /** - * Y Offset - * @description Approx percentage to offset (V) - * @default 0.5 - */ - y_offset?: number; - /** - * type - * @default offset_latents - * @constant - */ - type: "offset_latents"; - }; /** OffsetPaginatedResults[BoardDTO] */ OffsetPaginatedResults_BoardDTO_: { /** @@ -9392,58 +8023,12 @@ export type components = { * Total * @description Total number of items in result */ - total: number; - /** - * Items - * @description Items - */ - items: components["schemas"]["ImageDTO"][]; - }; - /** - * OnnxModelField - * @description Onnx model field - */ - OnnxModelField: { - /** - * Model Name - * @description Name of the model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Model Type */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - }; - /** - * ONNX Main Model - * @description Loads a main model, outputting its submodels. - */ - OnnxModelLoaderInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description ONNX Main model (UNet, VAE, CLIP) to load */ - model: components["schemas"]["OnnxModelField"]; + total: number; /** - * type - * @default onnx_model_loader - * @constant + * Items + * @description Items */ - type: "onnx_model_loader"; + items: components["schemas"]["ImageDTO"][]; }; /** PaginatedResults[ModelSummary] */ PaginatedResults_ModelSummary_: { @@ -10852,106 +9437,6 @@ export type components = { */ total: number; }; - /** - * Shadows/Highlights/Midtones - * @description Extract a Shadows/Highlights/Midtones mask from an image - */ - ShadowsHighlightsMidtonesMaskInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image from which to extract mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Highlight Threshold - * @description Threshold beyond which mask values will be at extremum - * @default 0.75 - */ - highlight_threshold?: number; - /** - * Upper Mid Threshold - * @description Threshold to which to extend mask border by 0..1 gradient - * @default 0.7 - */ - upper_mid_threshold?: number; - /** - * Lower Mid Threshold - * @description Threshold to which to extend mask border by 0..1 gradient - * @default 0.3 - */ - lower_mid_threshold?: number; - /** - * Shadow Threshold - * @description Threshold beyond which mask values will be at extremum - * @default 0.25 - */ - shadow_threshold?: number; - /** - * Mask Expand Or Contract - * @description Pixels to grow (or shrink) the mask areas - * @default 0 - */ - mask_expand_or_contract?: number; - /** - * Mask Blur - * @description Gaussian blur radius to apply to the masks - * @default 0 - */ - mask_blur?: number; - /** - * type - * @default shmmask - * @constant - */ - type: "shmmask"; - }; - /** ShadowsHighlightsMidtonesMasksOutput */ - ShadowsHighlightsMidtonesMasksOutput: { - /** @description Soft-edged highlights mask */ - highlights_mask?: components["schemas"]["ImageField"]; - /** @description Soft-edged midtones mask */ - midtones_mask?: components["schemas"]["ImageField"]; - /** @description Soft-edged shadows mask */ - shadows_mask?: components["schemas"]["ImageField"]; - /** - * Width - * @description Width of the input/outputs - */ - width: number; - /** - * Height - * @description Height of the input/outputs - */ - height: number; - /** - * type - * @default shmmask_output - * @constant - */ - type: "shmmask_output"; - }; /** * Show Image * @description Displays a provided image using the OS image viewer, and passes it forward in the pipeline. @@ -11833,249 +10318,6 @@ export type components = { /** Right */ right: number; }; - /** - * Text Mask - * @description Creates a 2D rendering of a text mask from a given font - */ - TextMaskInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Width - * @description The width of the desired mask - * @default 512 - */ - width?: number; - /** - * Height - * @description The height of the desired mask - * @default 512 - */ - height?: number; - /** - * Text - * @description The text to render - * @default - */ - text?: string; - /** - * Font - * @description Path to a FreeType-supported TTF/OTF font file - * @default - */ - font?: string; - /** - * Size - * @description Desired point size of text to use - * @default 64 - */ - size?: number; - /** - * Angle - * @description Angle of rotation to apply to the text - * @default 0 - */ - angle?: number; - /** - * X Offset - * @description x-offset for text rendering - * @default 24 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for text rendering - * @default 36 - */ - y_offset?: number; - /** - * Invert - * @description Whether to invert color of the output - * @default false - */ - invert?: boolean; - /** - * type - * @default text_mask - * @constant - */ - type: "text_mask"; - }; - /** - * Text to Mask Advanced (Clipseg) - * @description Uses the Clipseg model to generate an image mask from a text prompt - */ - TextToMaskClipsegAdvancedInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Prompt 1 - * @description First prompt with which to create a mask - */ - prompt_1?: string; - /** - * Prompt 2 - * @description Second prompt with which to create a mask (optional) - */ - prompt_2?: string; - /** - * Prompt 3 - * @description Third prompt with which to create a mask (optional) - */ - prompt_3?: string; - /** - * Prompt 4 - * @description Fourth prompt with which to create a mask (optional) - */ - prompt_4?: string; - /** - * Combine - * @description How to combine the results - * @default or - * @enum {string} - */ - combine?: "or" | "and" | "none (rgba multiplex)"; - /** - * Smoothing - * @description Radius of blur to apply before thresholding - * @default 4 - */ - smoothing?: number; - /** - * Subject Threshold - * @description Threshold above which is considered the subject - * @default 1 - */ - subject_threshold?: number; - /** - * Background Threshold - * @description Threshold below which is considered the background - * @default 0 - */ - background_threshold?: number; - /** - * type - * @default txt2mask_clipseg_adv - * @constant - */ - type: "txt2mask_clipseg_adv"; - }; - /** - * Text to Mask (Clipseg) - * @description Uses the Clipseg model to generate an image mask from a text prompt - */ - TextToMaskClipsegInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - 
/** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Prompt - * @description The prompt with which to create a mask - */ - prompt?: string; - /** - * Smoothing - * @description Radius of blur to apply before thresholding - * @default 4 - */ - smoothing?: number; - /** - * Subject Threshold - * @description Threshold above which is considered the subject - * @default 0.4 - */ - subject_threshold?: number; - /** - * Background Threshold - * @description Threshold below which is considered the background - * @default 0.4 - */ - background_threshold?: number; - /** - * Mask Expand Or Contract - * @description Pixels by which to grow (or shrink) mask after thresholding - * @default 0 - */ - mask_expand_or_contract?: number; - /** - * Mask Blur - * @description Radius of blur to apply after thresholding - * @default 0 - */ - mask_blur?: number; - /** - * type - * @default txt2mask_clipseg - * @constant - */ - type: "txt2mask_clipseg"; - }; /** * TextualInversionConfig * @description Model config for textual inversion embeddings. @@ -13067,53 +11309,53 @@ export type components = { */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** - * T2IAdapterModelFormat + * StableDiffusionXLModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** - * IPAdapterModelFormat + * T2IAdapterModelFormat * @description An enumeration. 
* @enum {string} */ - IPAdapterModelFormat: "invokeai"; + T2IAdapterModelFormat: "diffusers"; /** - * StableDiffusionXLModelFormat + * CLIPVisionModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; + CLIPVisionModelFormat: "diffusers"; /** - * StableDiffusionOnnxModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; + ControlNetModelFormat: "checkpoint" | "diffusers"; /** - * CLIPVisionModelFormat + * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ - CLIPVisionModelFormat: "diffusers"; + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat + * IPAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + IPAdapterModelFormat: "invokeai"; /** - * StableDiffusion1ModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** - * ControlNetModelFormat + * StableDiffusionOnnxModelFormat * @description An enumeration. * @enum {string} */ - ControlNetModelFormat: "checkpoint" | "diffusers"; + StableDiffusionOnnxModelFormat: "olive" | "onnx"; }; responses: never; parameters: never;
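
For reviewers skimming this hunk, a minimal TypeScript sketch of how downstream client code consumes the generated `components` type may help; the import path and type names below are illustrative assumptions (the usual openapi-typescript access pattern), not part of this patch. The point it demonstrates: the enum reshuffling at the end of the file is a no-op at the type level, since member order inside a TypeScript object type is irrelevant, whereas the schema deletions above are compile-time breaking for any client still referencing them.

// Hypothetical consumer of the generated schema types; illustrative only,
// assuming the conventional openapi-typescript import path.
import type { components } from 'services/api/schema';

// Member order inside an object type does not affect the type, so the
// reordered model-format enums above are type-identical before and after
// this diff.
type SD1Format = components['schemas']['StableDiffusion1ModelFormat']; // "checkpoint" | "diffusers"

// Deleting a schema, by contrast, breaks any consumer that still references
// it. After this patch the following would fail to type-check:
// type T2L = components['schemas']['ONNXTextToLatentsInvocation']; // error: property no longer exists

const format: SD1Format = 'diffusers';

In other words, a consumer can verify this hunk is safe by checking that no remaining code performs an indexed access on the removed schema names (the ONNX invocations and the community mask nodes); the enum block at the bottom needs no such audit.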