diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 149d47fb962..65607c436a5 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -151,6 +151,8 @@ def custom_openapi() -> dict[str, Any]: # TODO: note that we assume the schema_key here is the TYPE.__name__ # This could break in some cases, figure out a better way to do it output_type_titles[schema_key] = output_schema["title"] + openapi_schema["components"]["schemas"][schema_key] = output_schema + openapi_schema["components"]["schemas"][schema_key]["class"] = "output" # Add Node Editor UI helper schemas ui_config_schemas = models_json_schema( @@ -173,7 +175,6 @@ def custom_openapi() -> dict[str, Any]: outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} invoker_schema["output"] = outputs_ref invoker_schema["class"] = "invocation" - openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output" # This code no longer seems to be necessary? # Leave it here just in case diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 3243714937f..5edae5342df 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -8,13 +8,26 @@ from abc import ABC, abstractmethod from enum import Enum from inspect import signature -from types import UnionType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + ClassVar, + Iterable, + Literal, + Optional, + Type, + TypeVar, + Union, + cast, +) import semver -from pydantic import BaseModel, ConfigDict, Field, create_model +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined +from typing_extensions import TypeAliasType from invokeai.app.invocations.fields import ( FieldKind, @@ -84,6 +97,7 @@ class BaseInvocationOutput(BaseModel): """ _output_classes: ClassVar[set[BaseInvocationOutput]] = set() + _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None @classmethod def register_output(cls, output: BaseInvocationOutput) -> None: @@ -96,10 +110,14 @@ def get_outputs(cls) -> Iterable[BaseInvocationOutput]: return cls._output_classes @classmethod - def get_outputs_union(cls) -> UnionType: - """Gets a union of all invocation outputs.""" - outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type] - return outputs_union # type: ignore [return-value] + def get_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantc TypeAdapter for the union of all invocation output types.""" + if not cls._typeadapter: + InvocationOutputsUnion = TypeAliasType( + "InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] + ) + cls._typeadapter = TypeAdapter(InvocationOutputsUnion) + return cls._typeadapter @classmethod def get_output_types(cls) -> Iterable[str]: @@ -148,6 +166,7 @@ class BaseInvocation(ABC, BaseModel): """ _invocation_classes: ClassVar[set[BaseInvocation]] = set() + _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None @classmethod def get_type(cls) -> str: @@ -160,10 +179,14 @@ def register_invocation(cls, invocation: BaseInvocation) -> None: cls._invocation_classes.add(invocation) @classmethod - def get_invocations_union(cls) -> UnionType: - """Gets a union of all invocation types.""" - invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type] - return invocations_union # type: ignore [return-value] + def get_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantc TypeAdapter for the union of all invocation types.""" + if not cls._typeadapter: + InvocationsUnion = TypeAliasType( + "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] + ) + cls._typeadapter = TypeAdapter(InvocationsUnion) + return cls._typeadapter @classmethod def get_invocations(cls) -> Iterable[BaseInvocation]: diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 517da4375e1..47be380626b 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -417,7 +417,7 @@ class ClipSkipInvocation(BaseInvocation): """Skip layers in clip text_encoder model.""" clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") - skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) + skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers) def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers diff --git a/invokeai/app/services/shared/default_graphs.py b/invokeai/app/services/shared/default_graphs.py deleted file mode 100644 index 7e62c6d0a1b..00000000000 --- a/invokeai/app/services/shared/default_graphs.py +++ /dev/null @@ -1,92 +0,0 @@ -from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC - -from ...invocations.compel import CompelInvocation -from ...invocations.image import ImageNSFWBlurInvocation -from ...invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation -from ...invocations.noise import NoiseInvocation -from ...invocations.primitives import IntegerInvocation -from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph - -default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74" - - -def create_text_to_image() -> LibraryGraph: - graph = Graph( - nodes={ - "width": IntegerInvocation(id="width", value=512), - "height": IntegerInvocation(id="height", value=512), - "seed": IntegerInvocation(id="seed", value=-1), - "3": NoiseInvocation(id="3"), - "4": CompelInvocation(id="4"), - "5": CompelInvocation(id="5"), - "6": DenoiseLatentsInvocation(id="6"), - "7": LatentsToImageInvocation(id="7"), - "8": ImageNSFWBlurInvocation(id="8"), - }, - edges=[ - Edge( - source=EdgeConnection(node_id="width", field="value"), - destination=EdgeConnection(node_id="3", field="width"), - ), - Edge( - source=EdgeConnection(node_id="height", field="value"), - destination=EdgeConnection(node_id="3", field="height"), - ), - Edge( - source=EdgeConnection(node_id="seed", field="value"), - destination=EdgeConnection(node_id="3", field="seed"), - ), - Edge( - source=EdgeConnection(node_id="3", field="noise"), - destination=EdgeConnection(node_id="6", field="noise"), - ), - Edge( - source=EdgeConnection(node_id="6", field="latents"), - destination=EdgeConnection(node_id="7", field="latents"), - ), - Edge( - source=EdgeConnection(node_id="4", field="conditioning"), - destination=EdgeConnection(node_id="6", field="positive_conditioning"), - ), - Edge( - source=EdgeConnection(node_id="5", field="conditioning"), - destination=EdgeConnection(node_id="6", field="negative_conditioning"), - ), - Edge( - source=EdgeConnection(node_id="7", field="image"), - destination=EdgeConnection(node_id="8", field="image"), - ), - ], - ) - return LibraryGraph( - id=default_text_to_image_graph_id, - name="t2i", - description="Converts text to an image", - graph=graph, - exposed_inputs=[ - ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"), - ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"), - ExposedNodeInput(node_path="width", field="value", alias="width"), - ExposedNodeInput(node_path="height", field="value", alias="height"), - ExposedNodeInput(node_path="seed", field="value", alias="seed"), - ], - exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")], - ) - - -def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]: - """Creates the default system graphs, or adds new versions if the old ones don't match""" - - # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes - graphs: list[LibraryGraph] = [] - - text_to_image = graph_library.get(default_text_to_image_graph_id) - - # TODO: Check if the graph is the same as the default one, and if not, update it - # if text_to_image is None: - text_to_image = create_text_to_image() - graph_library.set(text_to_image) - - graphs.append(text_to_image) - - return graphs diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 3df230f5ee7..e3941d9ca37 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -5,8 +5,14 @@ from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import networkx as nx -from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from pydantic import ( + BaseModel, + GetJsonSchemaHandler, + field_validator, +) from pydantic.fields import Field +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import CoreSchema # Importing * is bad karma but needed here for node detection from invokeai.app.invocations import * # noqa: F401 F403 @@ -176,10 +182,6 @@ class NodeIdMismatchError(ValueError): pass -class InvalidSubGraphError(ValueError): - pass - - class CyclicalGraphError(ValueError): pass @@ -188,25 +190,6 @@ class UnknownGraphValidationError(ValueError): pass -# TODO: Create and use an Empty output? -@invocation_output("graph_output") -class GraphInvocationOutput(BaseInvocationOutput): - pass - - -# TODO: Fill this out and move to invocations -@invocation("graph", version="1.0.0") -class GraphInvocation(BaseInvocation): - """Execute a graph""" - - # TODO: figure out how to create a default here - graph: "Graph" = InputField(description="The graph to run", default=None) - - def invoke(self, context: InvocationContext) -> GraphInvocationOutput: - """Invoke with provided services and return outputs.""" - return GraphInvocationOutput() - - @invocation_output("iterate_output") class IterateInvocationOutput(BaseInvocationOutput): """Used to connect iteration outputs. Will be expanded to a specific output.""" @@ -260,21 +243,73 @@ def invoke(self, context: InvocationContext) -> CollectInvocationOutput: return CollectInvocationOutput(collection=copy.copy(self.collection)) -InvocationsUnion: Any = BaseInvocation.get_invocations_union() -InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union() - - class Graph(BaseModel): id: str = Field(description="The id of this graph", default_factory=uuid_string) # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me - nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field( - description="The nodes in this graph", default_factory=dict - ) + nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict) edges: list[Edge] = Field( description="The connections between nodes and their fields in this graph", default_factory=list, ) + @field_validator("nodes", mode="plain") + @classmethod + def validate_nodes(cls, v: dict[str, Any]): + """Validates the nodes in the graph by retrieving a union of all node types and validating each node.""" + + # Invocations register themselves as their python modules are executed. The union of all invocations is + # constructed at runtime. We use pydantic to validate `Graph.nodes` using that union. + # + # It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If + # we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing + # invocations will cause a graph to fail if they are used. + # + # We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the + # pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime. + # + # This same pattern is used in `GraphExecutionState`. + + nodes: dict[str, BaseInvocation] = {} + typeadapter = BaseInvocation.get_typeadapter() + for node_id, node in v.items(): + nodes[node_id] = typeadapter.validate_python(node) + return nodes + + @classmethod + def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + # We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for + # fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to + # the generated schema as options for the `nodes` field. + # + # The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and + # with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as + # expected. + # + # You might be tempted to do something like this: + # + # ```py + # cloned_model = create_model(cls.__name__, __base__=cls, nodes=...) + # delattr(cloned_model, "validate_nodes") + # cloned_model.model_rebuild(force=True) + # json_schema = handler(cloned_model.__pydantic_core_schema__) + # ``` + # + # Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts + # to build the JSON Schema for the cloned model. Instead, we have to manually clone the model. + # + # This same pattern is used in `GraphExecutionState`. + + class Graph(BaseModel): + id: Optional[str] = Field(default=None, description="The id of this graph") + nodes: dict[ + str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")] + ] = Field(description="The nodes in this graph") + edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph") + + json_schema = handler(Graph.__pydantic_core_schema__) + json_schema = handler.resolve_ref_schema(json_schema) + return json_schema + def add_node(self, node: BaseInvocation) -> None: """Adds a node to a graph @@ -286,41 +321,21 @@ def add_node(self, node: BaseInvocation) -> None: self.nodes[node.id] = node - def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]: - """Returns the graph and node id for a node path.""" - # Materialized graphs may have nodes at the top level - if node_path in self.nodes: - return (self, node_path) - - node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] - if node_id not in self.nodes: - raise NodeNotFoundError(f"Node {node_path} not found in graph") - - node = self.nodes[node_id] - - if not isinstance(node, GraphInvocation): - # There's more node path left but this isn't a graph - failure - raise NodeNotFoundError("Node path terminated early at a non-graph node") - - return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :]) - - def delete_node(self, node_path: str) -> None: + def delete_node(self, node_id: str) -> None: """Deletes a node from a graph""" try: - graph, node_id = self._get_graph_and_node(node_path) - # Delete edges for this node - input_edges = self._get_input_edges_and_graphs(node_path) - output_edges = self._get_output_edges_and_graphs(node_path) + input_edges = self._get_input_edges(node_id) + output_edges = self._get_output_edges(node_id) - for edge_graph, _, edge in input_edges: - edge_graph.delete_edge(edge) + for edge in input_edges: + self.delete_edge(edge) - for edge_graph, _, edge in output_edges: - edge_graph.delete_edge(edge) + for edge in output_edges: + self.delete_edge(edge) - del graph.nodes[node_id] + del self.nodes[node_id] except NodeNotFoundError: pass # Ignore, not doesn't exist (should this throw?) @@ -370,13 +385,6 @@ def validate_self(self) -> None: if k != v.id: raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}") - # Validate all subgraphs - for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)): - try: - gn.graph.validate_self() - except Exception as e: - raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e - # Validate that all edges match nodes and fields in the graph for edge in self.edges: source_node = self.nodes.get(edge.source.node_id, None) @@ -438,7 +446,6 @@ def is_valid(self) -> bool: except ( DuplicateNodeIdError, NodeIdMismatchError, - InvalidSubGraphError, NodeNotFoundError, NodeFieldNotFoundError, CyclicalGraphError, @@ -459,7 +466,7 @@ def _is_destination_field_list_of_Any(self, edge: Edge) -> bool: def _validate_edge(self, edge: Edge): """Validates that a new edge doesn't create a cycle in the graph""" - # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) + # Validate that the nodes exist try: from_node = self.get_node(edge.source.node_id) to_node = self.get_node(edge.destination.node_id) @@ -526,171 +533,90 @@ def _validate_edge(self, edge: Edge): f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" ) - def has_node(self, node_path: str) -> bool: + def has_node(self, node_id: str) -> bool: """Determines whether or not a node exists in the graph.""" try: - n = self.get_node(node_path) - if n is not None: - return True - else: - return False + _ = self.get_node(node_id) + return True except NodeNotFoundError: return False - def get_node(self, node_path: str) -> BaseInvocation: - """Gets a node from the graph using a node path.""" - # Materialized graphs may have nodes at the top level - graph, node_id = self._get_graph_and_node(node_path) - return graph.nodes[node_id] - - def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str: - return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}" + def get_node(self, node_id: str) -> BaseInvocation: + """Gets a node from the graph.""" + try: + return self.nodes[node_id] + except KeyError as e: + raise NodeNotFoundError(f"Node {node_id} not found in graph") from e - def update_node(self, node_path: str, new_node: BaseInvocation) -> None: + def update_node(self, node_id: str, new_node: BaseInvocation) -> None: """Updates a node in the graph.""" - graph, node_id = self._get_graph_and_node(node_path) - node = graph.nodes[node_id] + node = self.nodes[node_id] # Ensure the node type matches the new node if type(node) is not type(new_node): - raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}") + raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}") # Ensure the new id is either the same or is not in the graph - prefix = None if "." not in node_path else node_path[: node_path.rindex(".")] - new_path = self._get_node_path(new_node.id, prefix=prefix) - if new_node.id != node.id and self.has_node(new_path): - raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph") + if new_node.id != node.id and self.has_node(new_node.id): + raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph") # Set the new node in the graph - graph.nodes[new_node.id] = new_node + self.nodes[new_node.id] = new_node if new_node.id != node.id: - input_edges = self._get_input_edges_and_graphs(node_path) - output_edges = self._get_output_edges_and_graphs(node_path) + input_edges = self._get_input_edges(node_id) + output_edges = self._get_output_edges(node_id) # Delete node and all edges - graph.delete_node(node_path) + self.delete_node(node_id) # Create new edges for each input and output - for graph, _, edge in input_edges: - # Remove the graph prefix from the node path - new_graph_node_path = ( - new_node.id - if "." not in edge.destination.node_id - else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}' - ) - graph.add_edge( + for edge in input_edges: + self.add_edge( Edge( source=edge.source, - destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field), + destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field), ) ) - for graph, _, edge in output_edges: - # Remove the graph prefix from the node path - new_graph_node_path = ( - new_node.id - if "." not in edge.source.node_id - else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}' - ) - graph.add_edge( + for edge in output_edges: + self.add_edge( Edge( - source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field), + source=EdgeConnection(node_id=new_node.id, field=edge.source.field), destination=edge.destination, ) ) - def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]: - """Gets all input edges for a node""" - edges = self._get_input_edges_and_graphs(node_path) - - # Filter to edges that match the field - filtered_edges = (e for e in edges if field is None or e[2].destination.field == field) - - # Create full node paths for each edge - return [ - Edge( - source=EdgeConnection( - node_id=self._get_node_path(e.source.node_id, prefix=prefix), - field=e.source.field, - ), - destination=EdgeConnection( - node_id=self._get_node_path(e.destination.node_id, prefix=prefix), - field=e.destination.field, - ), - ) - for _, prefix, e in filtered_edges - ] + def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]: + """Gets all input edges for a node. If field is provided, only edges to that field are returned.""" - def _get_input_edges_and_graphs( - self, node_path: str, prefix: Optional[str] = None - ) -> list[tuple["Graph", Union[str, None], Edge]]: - """Gets all input edges for a node along with the graph they are in and the graph's path""" - edges = [] + edges = [e for e in self.edges if e.destination.node_id == node_id] - # Return any input edges that appear in this graph - edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]) + if field is None: + return edges - node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] - node = self.nodes[node_id] + filtered_edges = [e for e in edges if e.destination.field == field] - if isinstance(node, GraphInvocation): - graph = node.graph - graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix) - graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path) - edges.extend(graph_edges) - - return edges - - def _get_output_edges(self, node_path: str, field: str) -> list[Edge]: - """Gets all output edges for a node""" - edges = self._get_output_edges_and_graphs(node_path) - - # Filter to edges that match the field - filtered_edges = (e for e in edges if e[2].source.field == field) - - # Create full node paths for each edge - return [ - Edge( - source=EdgeConnection( - node_id=self._get_node_path(e.source.node_id, prefix=prefix), - field=e.source.field, - ), - destination=EdgeConnection( - node_id=self._get_node_path(e.destination.node_id, prefix=prefix), - field=e.destination.field, - ), - ) - for _, prefix, e in filtered_edges - ] + return filtered_edges - def _get_output_edges_and_graphs( - self, node_path: str, prefix: Optional[str] = None - ) -> list[tuple["Graph", Union[str, None], Edge]]: - """Gets all output edges for a node along with the graph they are in and the graph's path""" - edges = [] + def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]: + """Gets all output edges for a node. If field is provided, only edges from that field are returned.""" + edges = [e for e in self.edges if e.source.node_id == node_id] - # Return any input edges that appear in this graph - edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path]) - - node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] - node = self.nodes[node_id] + if field is None: + return edges - if isinstance(node, GraphInvocation): - graph = node.graph - graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix) - graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path) - edges.extend(graph_edges) + filtered_edges = [e for e in edges if e.source.field == field] - return edges + return filtered_edges def _is_iterator_connection_valid( self, - node_path: str, + node_id: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = [e.source for e in self._get_input_edges(node_path, "collection")] - outputs = [e.destination for e in self._get_output_edges(node_path, "item")] + inputs = [e.source for e in self._get_input_edges(node_id, "collection")] + outputs = [e.destination for e in self._get_output_edges(node_id, "item")] if new_input is not None: inputs.append(new_input) @@ -718,12 +644,12 @@ def _is_iterator_connection_valid( def _is_collector_connection_valid( self, - node_path: str, + node_id: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = [e.source for e in self._get_input_edges(node_path, "item")] - outputs = [e.destination for e in self._get_output_edges(node_path, "collection")] + inputs = [e.source for e in self._get_input_edges(node_id, "item")] + outputs = [e.destination for e in self._get_output_edges(node_id, "collection")] if new_input is not None: inputs.append(new_input) @@ -779,27 +705,17 @@ def nx_graph_with_data(self) -> nx.DiGraph: g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges}) return g - def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph: + def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph: """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)""" g = nx_graph or nx.DiGraph() # Add all nodes from this graph except graph/iteration nodes - g.add_nodes_from( - [ - self._get_node_path(n.id, prefix) - for n in self.nodes.values() - if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation) - ] - ) - - # Expand graph nodes - for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)): - g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) + g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)]) # TODO: figure out if iteration nodes need to be expanded unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges} - g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]) + g.add_edges_from([(e[0], e[1]) for e in unique_edges]) return g @@ -824,9 +740,7 @@ class GraphExecutionState(BaseModel): ) # The results of executed nodes - results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field( - description="The results of node executions", default_factory=dict - ) + results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict) # Errors raised when executing nodes errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) @@ -843,27 +757,51 @@ class GraphExecutionState(BaseModel): default_factory=dict, ) + @field_validator("results", mode="plain") + @classmethod + def validate_results(cls, v: dict[str, BaseInvocationOutput]): + """Validates the results in the GES by retrieving a union of all output types and validating each result.""" + + # See the comment in `Graph.validate_nodes` for an explanation of this logic. + results: dict[str, BaseInvocationOutput] = {} + typeadapter = BaseInvocationOutput.get_typeadapter() + for result_id, result in v.items(): + results[result_id] = typeadapter.validate_python(result) + return results + @field_validator("graph") def graph_is_valid(cls, v: Graph): """Validates that the graph is valid""" v.validate_self() return v - model_config = ConfigDict( - json_schema_extra={ - "required": [ - "id", - "graph", - "execution_graph", - "executed", - "executed_history", - "results", - "errors", - "prepared_source_mapping", - "source_prepared_mapping", - ] - } - ) + @classmethod + def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic. + class GraphExecutionState(BaseModel): + """Tracks the state of a graph execution""" + + id: str = Field(description="The id of the execution state") + graph: Graph = Field(description="The graph being executed") + execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes") + executed: set[str] = Field(description="The set of node ids that have been executed") + executed_history: list[str] = Field( + description="The list of node ids that have been executed, in order of execution" + ) + results: dict[ + str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")] + ] = Field(description="The results of node executions") + errors: dict[str, str] = Field(description="Errors raised when executing nodes") + prepared_source_mapping: dict[str, str] = Field( + description="The map of prepared nodes to original graph nodes" + ) + source_prepared_mapping: dict[str, set[str]] = Field( + description="The map of original graph nodes to prepared nodes" + ) + + json_schema = handler(GraphExecutionState.__pydantic_core_schema__) + json_schema = handler.resolve_ref_schema(json_schema) + return json_schema def next(self) -> Optional[BaseInvocation]: """Gets the next node ready to execute.""" @@ -919,17 +857,17 @@ def has_error(self) -> bool: """Returns true if the graph has any errors""" return len(self.errors) > 0 - def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: + def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: """Prepares an iteration node and connects all edges, returning the new node id""" - node = self.graph.get_node(node_path) + node = self.graph.get_node(node_id) self_iteration_count = -1 # If this is an iterator node, we must create a copy for each iteration if isinstance(node, IterateInvocation): # Get input collection edge (should error if there are no inputs) - input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection"))) + input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection"))) input_collection_prepared_node_id = next( n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id ) @@ -943,7 +881,7 @@ def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[ return new_nodes # Get all input edges - input_edges = self.graph._get_input_edges(node_path) + input_edges = self.graph._get_input_edges(node_id) # Create new edges for this iteration # For collect nodes, this may contain multiple inputs to the same field @@ -970,10 +908,10 @@ def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[ # Add to execution graph self.execution_graph.add_node(new_node) - self.prepared_source_mapping[new_node.id] = node_path - if node_path not in self.source_prepared_mapping: - self.source_prepared_mapping[node_path] = set() - self.source_prepared_mapping[node_path].add(new_node.id) + self.prepared_source_mapping[new_node.id] = node_id + if node_id not in self.source_prepared_mapping: + self.source_prepared_mapping[node_id] = set() + self.source_prepared_mapping[node_id].add(new_node.id) # Add new edges to execution graph for edge in new_edges: @@ -1077,13 +1015,13 @@ def _prepare(self) -> Optional[str]: def _get_iteration_node( self, - source_node_path: str, + source_node_id: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str], ) -> Optional[str]: """Gets the prepared version of the specified source node that matches every iteration specified""" - prepared_nodes = self.source_prepared_mapping[source_node_path] + prepared_nodes = self.source_prepared_mapping[source_node_id] if len(prepared_nodes) == 1: return next(iter(prepared_nodes)) @@ -1094,7 +1032,7 @@ def _get_iteration_node( # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] - parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)] + parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)] return next( (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)), @@ -1163,19 +1101,19 @@ def _is_node_updatable(self, node_id: str) -> bool: def add_node(self, node: BaseInvocation) -> None: self.graph.add_node(node) - def update_node(self, node_path: str, new_node: BaseInvocation) -> None: - if not self._is_node_updatable(node_path): + def update_node(self, node_id: str, new_node: BaseInvocation) -> None: + if not self._is_node_updatable(node_id): raise NodeAlreadyExecutedError( - f"Node {node_path} has already been prepared or executed and cannot be updated" + f"Node {node_id} has already been prepared or executed and cannot be updated" ) - self.graph.update_node(node_path, new_node) + self.graph.update_node(node_id, new_node) - def delete_node(self, node_path: str) -> None: - if not self._is_node_updatable(node_path): + def delete_node(self, node_id: str) -> None: + if not self._is_node_updatable(node_id): raise NodeAlreadyExecutedError( - f"Node {node_path} has already been prepared or executed and cannot be deleted" + f"Node {node_id} has already been prepared or executed and cannot be deleted" ) - self.graph.delete_node(node_path) + self.graph.delete_node(node_id) def add_edge(self, edge: Edge) -> None: if not self._is_node_updatable(edge.destination.node_id): @@ -1190,63 +1128,3 @@ def delete_edge(self, edge: Edge) -> None: f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted" ) self.graph.delete_edge(edge) - - -class ExposedNodeInput(BaseModel): - node_path: str = Field(description="The node path to the node with the input") - field: str = Field(description="The field name of the input") - alias: str = Field(description="The alias of the input") - - -class ExposedNodeOutput(BaseModel): - node_path: str = Field(description="The node path to the node with the output") - field: str = Field(description="The field name of the output") - alias: str = Field(description="The alias of the output") - - -class LibraryGraph(BaseModel): - id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string) - graph: Graph = Field(description="The graph") - name: str = Field(description="The name of the graph") - description: str = Field(description="The description of the graph") - exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list) - exposed_outputs: list[ExposedNodeOutput] = Field( - description="The outputs exposed by this graph", default_factory=list - ) - - @field_validator("exposed_inputs", "exposed_outputs") - def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]): - if len(v) != len({i.alias for i in v}): - raise ValueError("Duplicate exposed alias") - return v - - @model_validator(mode="after") - def validate_exposed_nodes(cls, values): - graph = values.graph - - # Validate exposed inputs - for exposed_input in values.exposed_inputs: - if not graph.has_node(exposed_input.node_path): - raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist") - node = graph.get_node(exposed_input.node_path) - if get_input_field(node, exposed_input.field) is None: - raise ValueError( - f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}" - ) - - # Validate exposed outputs - for exposed_output in values.exposed_outputs: - if not graph.has_node(exposed_output.node_path): - raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist") - node = graph.get_node(exposed_output.node_path) - if get_output_field(node, exposed_output.field) is None: - raise ValueError( - f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}" - ) - - return values - - -GraphInvocation.model_rebuild(force=True) -Graph.model_rebuild(force=True) -GraphExecutionState.model_rebuild(force=True) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 40fc262be26..47a257ffe6c 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2011,8 +2011,9 @@ export type components = { /** * CLIP * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null */ - clip?: components["schemas"]["ClipField"] | null; + clip: components["schemas"]["ClipField"] | null; /** * type * @default clip_skip_output @@ -3264,6 +3265,7 @@ export type components = { /** * Masked Latents Name * @description The name of the masked image latents + * @default null */ masked_latents_name?: string | null; }; @@ -4211,14 +4213,14 @@ export type components = { * Nodes * @description The nodes in this graph */ - nodes?: { - [key: string]: components["schemas"]["ControlNetInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CvInpaintInvocation"]; + nodes: { + [key: string]: components["schemas"]["ImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["CompelInvocation"]; }; /** * Edges * @description The connections between nodes and their fields in this graph */ - edges?: components["schemas"]["Edge"][]; + edges: components["schemas"]["Edge"][]; }; /** * GraphExecutionState @@ -4249,7 +4251,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["SchedulerOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["String2Output"] | components["schemas"]["IntegerOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["IterateInvocationOutput"]; + [key: string]: components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ControlOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"]; }; /** * Errors @@ -4273,46 +4275,6 @@ export type components = { [key: string]: string[]; }; }; - /** - * GraphInvocation - * @description Execute a graph - */ - GraphInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The graph to run */ - graph?: components["schemas"]["Graph"]; - /** - * type - * @default graph - * @constant - */ - type: "graph"; - }; - /** GraphInvocationOutput */ - GraphInvocationOutput: { - /** - * type - * @default graph_output - * @constant - */ - type: "graph_output"; - }; /** * HFModelSource * @description A HuggingFace repo_id with optional variant, sub-folder and access token. @@ -6218,6 +6180,7 @@ export type components = { /** * Seed * @description Seed used to generate this latents + * @default null */ seed?: number | null; }; @@ -6631,7 +6594,10 @@ export type components = { * @description Key of model as returned by ModelRecordServiceBase.get_model() */ key: string; - /** @description Info to load submodel */ + /** + * @description Info to load submodel + * @default null + */ submodel_type?: components["schemas"]["SubModelType"] | null; /** * Weight @@ -6697,13 +6663,15 @@ export type components = { /** * UNet * @description UNet (scheduler, LoRAs) + * @default null */ - unet?: components["schemas"]["UNetField"] | null; + unet: components["schemas"]["UNetField"] | null; /** * CLIP * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null */ - clip?: components["schemas"]["ClipField"] | null; + clip: components["schemas"]["ClipField"] | null; /** * type * @default lora_loader_output @@ -7420,7 +7388,10 @@ export type components = { * @description Key of model as returned by ModelRecordServiceBase.get_model() */ key: string; - /** @description Info to load submodel */ + /** + * @description Info to load submodel + * @default null + */ submodel_type?: components["schemas"]["SubModelType"] | null; }; /** @@ -8794,18 +8765,21 @@ export type components = { /** * UNet * @description UNet (scheduler, LoRAs) + * @default null */ - unet?: components["schemas"]["UNetField"] | null; + unet: components["schemas"]["UNetField"] | null; /** * CLIP 1 * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null */ - clip?: components["schemas"]["ClipField"] | null; + clip: components["schemas"]["ClipField"] | null; /** * CLIP 2 * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null */ - clip2?: components["schemas"]["ClipField"] | null; + clip2: components["schemas"]["ClipField"] | null; /** * type * @default sdxl_lora_loader_output @@ -9202,13 +9176,15 @@ export type components = { /** * UNet * @description UNet (scheduler, LoRAs) + * @default null */ - unet?: components["schemas"]["UNetField"] | null; + unet: components["schemas"]["UNetField"] | null; /** * VAE * @description VAE + * @default null */ - vae?: components["schemas"]["VaeField"] | null; + vae: components["schemas"]["VaeField"] | null; /** * type * @default seamless_output @@ -10397,7 +10373,10 @@ export type components = { * @description Axes("x" and "y") to which apply seamless */ seamless_axes?: string[]; - /** @description FreeU configuration */ + /** + * @description FreeU configuration + * @default null + */ freeu_config?: components["schemas"]["FreeUConfig"] | null; }; /** @@ -11113,17 +11092,17 @@ export type components = { */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** - * VaeModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - VaeModelFormat: "checkpoint" | "diffusers"; + ControlNetModelFormat: "checkpoint" | "diffusers"; /** - * T2IAdapterModelFormat + * LoRAModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + LoRAModelFormat: "lycoris" | "diffusers"; /** * StableDiffusionXLModelFormat * @description An enumeration. @@ -11131,47 +11110,47 @@ export type components = { */ StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion1ModelFormat + * IPAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + IPAdapterModelFormat: "invokeai"; /** - * StableDiffusionOnnxModelFormat + * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; + T2IAdapterModelFormat: "diffusers"; /** - * ControlNetModelFormat + * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ - ControlNetModelFormat: "checkpoint" | "diffusers"; + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat + * CLIPVisionModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + CLIPVisionModelFormat: "diffusers"; /** - * LoRAModelFormat + * StableDiffusionOnnxModelFormat * @description An enumeration. * @enum {string} */ - LoRAModelFormat: "lycoris" | "diffusers"; + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * CLIPVisionModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - CLIPVisionModelFormat: "diffusers"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** - * IPAdapterModelFormat + * VaeModelFormat * @description An enumeration. * @enum {string} */ - IPAdapterModelFormat: "invokeai"; + VaeModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; diff --git a/tests/aa_nodes/__init__.py b/tests/aa_nodes/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/test_graph_execution_state.py similarity index 100% rename from tests/aa_nodes/test_graph_execution_state.py rename to tests/test_graph_execution_state.py diff --git a/tests/aa_nodes/test_invoker.py b/tests/test_invoker.py similarity index 94% rename from tests/aa_nodes/test_invoker.py rename to tests/test_invoker.py index f67b5a2ac55..38fcf859a58 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/test_invoker.py @@ -23,7 +23,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.invoker import Invoker from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID -from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation +from invokeai.app.services.shared.graph import Graph, GraphExecutionState @pytest.fixture @@ -35,17 +35,6 @@ def simple_graph(): return g -@pytest.fixture -def graph_with_subgraph(): - sub_g = Graph() - sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) - sub_g.add_node(TextToImageTestInvocation(id="2")) - sub_g.add_edge(create_edge("1", "prompt", "2", "prompt")) - g = Graph() - g.add_node(GraphInvocation(id="1", graph=sub_g)) - return g - - # This must be defined here to avoid issues with the dynamic creation of the union of all invocation types # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # the test invocations. diff --git a/tests/aa_nodes/test_node_graph.py b/tests/test_node_graph.py similarity index 86% rename from tests/aa_nodes/test_node_graph.py rename to tests/test_node_graph.py index 12a181f392f..87a4948af40 100644 --- a/tests/aa_nodes/test_node_graph.py +++ b/tests/test_node_graph.py @@ -8,8 +8,6 @@ invocation, invocation_output, ) -from invokeai.app.invocations.image import ShowImageInvocation -from invokeai.app.invocations.math import AddInvocation, SubtractInvocation from invokeai.app.invocations.primitives import ( FloatCollectionInvocation, FloatInvocation, @@ -17,13 +15,11 @@ StringInvocation, ) from invokeai.app.invocations.upscale import ESRGANInvocation -from invokeai.app.services.shared.default_graphs import create_text_to_image from invokeai.app.services.shared.graph import ( CollectInvocation, Edge, EdgeConnection, Graph, - GraphInvocation, InvalidEdgeError, IterateInvocation, NodeAlreadyInGraphError, @@ -425,21 +421,6 @@ def test_graph_invalid_if_edges_reference_missing_nodes(): assert g.is_valid() is False -def test_graph_invalid_if_subgraph_invalid(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() - - n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi") - n1.graph.nodes[n1_1.id] = n1_1 - e1 = create_edge("1", "image", "2", "image") - n1.graph.edges.append(e1) - - g.nodes[n1.id] = n1 - - assert g.is_valid() is False - - def test_graph_invalid_if_has_cycle(): g = Graph() n1 = ESRGANInvocation(id="1") @@ -466,110 +447,6 @@ def test_graph_invalid_with_invalid_connection(): assert g.is_valid() is False -# TODO: Subgraph operations -def test_graph_gets_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() - - n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") - n1.graph.add_node(n1_1) - - g.add_node(n1) - - result = g.get_node("1.1") - - assert result is not None - assert result.id == "1" - assert result == n1_1 - - -def test_graph_expands_subgraph(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() - - n1_1 = AddInvocation(id="1", a=1, b=2) - n1_2 = SubtractInvocation(id="2", b=3) - n1.graph.add_node(n1_1) - n1.graph.add_node(n1_2) - n1.graph.add_edge(create_edge("1", "value", "2", "a")) - - g.add_node(n1) - - n2 = AddInvocation(id="2", b=5) - g.add_node(n2) - g.add_edge(create_edge("1.2", "value", "2", "a")) - - dg = g.nx_graph_flat() - assert set(dg.nodes) == {"1.1", "1.2", "2"} - assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")} - - -def test_graph_subgraph_t2i(): - g = Graph() - n1 = GraphInvocation(id="1") - - # Get text to image default graph - lg = create_text_to_image() - n1.graph = lg.graph - - g.add_node(n1) - - n2 = IntegerInvocation(id="2", value=512) - n3 = IntegerInvocation(id="3", value=256) - - g.add_node(n2) - g.add_node(n3) - - g.add_edge(create_edge("2", "value", "1.width", "value")) - g.add_edge(create_edge("3", "value", "1.height", "value")) - - n4 = ShowImageInvocation(id="4") - g.add_node(n4) - g.add_edge(create_edge("1.8", "image", "4", "image")) - - # Validate - dg = g.nx_graph_flat() - assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"} - expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] - expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) - print(expected_edges) - print(list(dg.edges)) - assert set(dg.edges) == set(expected_edges) - - -def test_graph_fails_to_get_missing_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() - - n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") - n1.graph.add_node(n1_1) - - g.add_node(n1) - - with pytest.raises(NodeNotFoundError): - _ = g.get_node("1.2") - - -def test_graph_fails_to_enumerate_non_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() - - n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") - n1.graph.add_node(n1_1) - - g.add_node(n1) - - n2 = ESRGANInvocation(id="2") - g.add_node(n2) - - with pytest.raises(NodeNotFoundError): - _ = g.get_node("2.1") - - def test_graph_gets_networkx_graph(): g = Graph() n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") diff --git a/tests/aa_nodes/test_nodes.py b/tests/test_nodes.py similarity index 100% rename from tests/aa_nodes/test_nodes.py rename to tests/test_nodes.py diff --git a/tests/aa_nodes/test_session_queue.py b/tests/test_session_queue.py similarity index 89% rename from tests/aa_nodes/test_session_queue.py rename to tests/test_session_queue.py index b15bb9df360..bf26b9b0026 100644 --- a/tests/aa_nodes/test_session_queue.py +++ b/tests/test_session_queue.py @@ -8,11 +8,11 @@ NodeFieldValue, calc_session_count, create_session_nfv_tuples, - populate_graph, prepare_values_to_insert, ) -from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation -from tests.aa_nodes.test_nodes import PromptTestInvocation +from invokeai.app.services.shared.graph import Graph, GraphExecutionState + +from .test_nodes import PromptTestInvocation @pytest.fixture @@ -39,30 +39,6 @@ def batch_graph() -> Graph: return g -def test_populate_graph_with_subgraph(): - g1 = Graph() - g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) - g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi")) - n1 = PromptTestInvocation(id="1", prompt="Banana snake") - subgraph = Graph() - subgraph.add_node(n1) - g1.add_node(GraphInvocation(id="3", graph=subgraph)) - - nfvs = [ - NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"), - NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"), - NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"), - ] - - g2 = populate_graph(g1, nfvs) - - # do not mutate g1 - assert g1 is not g2 - assert g2.get_node("1").prompt == "Strawberry sushi" - assert g2.get_node("2").prompt == "Strawberry sunday" - assert g2.get_node("3.1").prompt == "Strawberry snake" - - def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph): b = Batch(graph=batch_graph, data=batch_data_collection, runs=2) t = list(create_session_nfv_tuples(batch=b, maximum=1000))