diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index d5624048e81..60d77d5e582 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -1,22 +1,24 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Annotated, Optional, Union +from typing import Annotated, Literal, Optional, Union from fastapi import Body, HTTPException, Path, Query, Response from fastapi.routing import APIRouter from pydantic.fields import Field +from invokeai.app.services.item_storage import PaginatedResults + # Importing * is bad karma but needed here for node detection from ...invocations import * # noqa: F401 F403 -from ...invocations.baseinvocation import BaseInvocation +from ...invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from ...services.graph import ( Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError, + update_invocations_union, ) -from ...services.item_storage import PaginatedResults from ..dependencies import ApiDependencies session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"]) @@ -38,6 +40,24 @@ async def create_session( return session +@session_router.post( + "/update_nodes", + operation_id="update_nodes", +) +async def update_nodes() -> None: + class TestFromRouterOutput(BaseInvocationOutput): + type: Literal["test_from_router"] = "test_from_router" + + class TestInvocationFromRouter(BaseInvocation): + type: Literal["test_from_router_output"] = "test_from_router_output" + + def invoke(self, context) -> TestFromRouterOutput: + return TestFromRouterOutput() + + # doesn't work from here... hmm... + update_invocations_union() + + @session_router.get( "/", operation_id="list_sessions", diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 902af0c02cc..affdd0903b5 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -1,10 +1,13 @@ # Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team import asyncio import logging +import mimetypes import socket from inspect import signature from pathlib import Path +from typing import Literal +import torch import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -15,23 +18,17 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware from pydantic.schema import schema -from .services.config import InvokeAIAppConfig -from ..backend.util.logging import InvokeAILogger - -from invokeai.version.invokeai_version import __version__ - +# noinspection PyUnresolvedReferences +import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.frontend.web as web_dir -import mimetypes - +from invokeai.app.services.graph import update_invocations_union +from invokeai.version.invokeai_version import __version__ from .api.dependencies import ApiDependencies from .api.routers import sessions, models, images, boards, board_images, app_info from .api.sockets import SocketIO -from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase - -import torch - -# noinspection PyUnresolvedReferences -import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) +from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, BaseInvocationOutput, UIConfigBase +from .services.config import InvokeAIAppConfig +from ..backend.util.logging import InvokeAILogger if torch.backends.mps.is_available(): # noinspection PyUnresolvedReferences @@ -104,8 +101,8 @@ async def shutdown_event(): # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? def custom_openapi(): - if app.openapi_schema: - return app.openapi_schema + # if app.openapi_schema: + # return app.openapi_schema openapi_schema = get_openapi( title=app.title, description="An API for invoking AI image operations", @@ -140,6 +137,9 @@ def custom_openapi(): invoker_name = invoker.__name__ output_type = signature(invoker.invoke).return_annotation output_type_title = output_type_titles[output_type.__name__] + if invoker_name not in openapi_schema["components"]["schemas"]: + openapi_schema["components"]["schemas"][invoker_name] = invoker.schema() + invoker_schema = openapi_schema["components"]["schemas"][invoker_name] outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} invoker_schema["output"] = outputs_ref @@ -211,14 +211,14 @@ def find_port(port: int): if app_config.dev_reload: try: - import jurigged + from invokeai.app.util.dev_reload import start_reloader except ImportError as e: logger.error( 'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.', exc_info=e, ) else: - jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info) + start_reloader() port = find_port(app_config.port) if port != app_config.port: @@ -242,6 +242,26 @@ def find_port(port: int): for ch in logger.handlers: log.addHandler(ch) + class Test1Output(BaseInvocationOutput): + type: Literal["test1_output"] = "test1_output" + + class Test1Invocation(BaseInvocation): + type: Literal["test1"] = "test1" + + def invoke(self, context) -> Test1Output: + return Test1Output() + + class Test2Output(BaseInvocationOutput): + type: Literal["test2_output"] = "test2_output" + + class TestInvocation2(BaseInvocation): + type: Literal["test2"] = "test2" + + def invoke(self, context) -> Test2Output: + return Test2Output() + + update_invocations_union() + loop.run_until_complete(server.serve()) diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 51cc8a30ae6..da9ea602ef1 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -7,7 +7,7 @@ import networkx as nx from pydantic import BaseModel, root_validator, validator -from pydantic.fields import Field +from pydantic.fields import Field, ModelField # Importing * is bad karma but needed here for node detection from ..invocations import * # noqa: F401 F403 @@ -232,7 +232,39 @@ def invoke(self, context: InvocationContext) -> CollectInvocationOutput: InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore -class Graph(BaseModel): +class DynamicBaseModel(BaseModel): + """https://github.com/pydantic/pydantic/issues/1937#issuecomment-695313040""" + + @classmethod + def add_fields(cls, **field_definitions: Any): + new_fields: dict[str, ModelField] = {} + new_annotations: dict[str, Optional[type]] = {} + + for f_name, f_def in field_definitions.items(): + if isinstance(f_def, tuple): + try: + f_annotation, f_value = f_def + except ValueError as e: + raise Exception( + "field definitions should either be a tuple of (, ) or just a " + "default value, unfortunately this means tuples as " + "default values are not allowed" + ) from e + else: + f_annotation, f_value = None, f_def + + if f_annotation: + new_annotations[f_name] = f_annotation + + new_fields[f_name] = ModelField.infer( + name=f_name, value=f_value, annotation=f_annotation, class_validators=None, config=cls.__config__ + ) + + cls.__fields__.update(new_fields) + cls.__annotations__.update(new_annotations) + + +class Graph(DynamicBaseModel): id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__()) # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field( @@ -700,7 +732,7 @@ def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[ return g -class GraphExecutionState(BaseModel): +class GraphExecutionState(DynamicBaseModel): """Tracks the state of a graph execution""" id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__()) @@ -1131,3 +1163,24 @@ def validate_exposed_nodes(cls, values): GraphInvocation.update_forward_refs() + + +def update_invocations_union() -> None: + global InvocationsUnion + global InvocationOutputsUnion + InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore + InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore + + Graph.add_fields( + nodes=( + dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]], + Field(description="The nodes in this graph", default_factory=dict), + ) + ) + + GraphExecutionState.add_fields( + results=( + dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]], + Field(description="The results of node executions", default_factory=dict), + ) + ) diff --git a/invokeai/app/util/dev_reload.py b/invokeai/app/util/dev_reload.py new file mode 100644 index 00000000000..92c4ae0aff9 --- /dev/null +++ b/invokeai/app/util/dev_reload.py @@ -0,0 +1,31 @@ +import jurigged +from jurigged.codetools import ClassDefinition + +from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.backend.util.logging import InvokeAILogger + +logger = InvokeAILogger.getLogger(name=__name__) + + +def reload_nodes(path: str, codefile: jurigged.CodeFile): + """Callback function for jurigged post-run events.""" + # Things we have access to here: + # codefile.module:module - the module object associated with this file + # codefile.module_name:str - the full module name (its key in sys.modules) + # codefile.root:ModuleCode - an AST of the current source + + # This is only reading top-level statements, not walking the whole AST, but class definition should be top-level, right? + class_names = [statement.name for statement in codefile.root.children if isinstance(statement, ClassDefinition)] + classes = [getattr(codefile.module, name) for name in class_names] + invocations = [cls for cls in classes if issubclass(cls, BaseInvocation)] + # outputs = [cls for cls in classes if issubclass(cls, BaseInvocationOutput)] + + # We should assume jurigged has already replaced all references to methods of these classes, + # but it hasn't re-executed any annotations on them (like @title or @tags). + # We need to re-do anything that involved introspection like BaseInvocation.get_all_subclasses() + logger.info("File reloaded: %s contains invocation classes %s", path, invocations) + + +def start_reloader(): + watcher = jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info) + watcher.postrun.register(reload_nodes, apply_history=False)