Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions invokeai/app/api/routers/sessions.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Annotated, Optional, Union
from typing import Annotated, Literal, Optional, Union

from fastapi import Body, HTTPException, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic.fields import Field

from invokeai.app.services.item_storage import PaginatedResults

# Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403
from ...invocations.baseinvocation import BaseInvocation
from ...invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from ...services.graph import (
Edge,
EdgeConnection,
Graph,
GraphExecutionState,
NodeAlreadyExecutedError,
update_invocations_union,
)
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies

session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
Expand All @@ -38,6 +40,24 @@ async def create_session(
return session


@session_router.post(
"/update_nodes",
operation_id="update_nodes",
)
async def update_nodes() -> None:
class TestFromRouterOutput(BaseInvocationOutput):
type: Literal["test_from_router"] = "test_from_router"

class TestInvocationFromRouter(BaseInvocation):
type: Literal["test_from_router_output"] = "test_from_router_output"

def invoke(self, context) -> TestFromRouterOutput:
return TestFromRouterOutput()

# doesn't work from here... hmm...
update_invocations_union()


@session_router.get(
"/",
operation_id="list_sessions",
Expand Down
54 changes: 38 additions & 16 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import asyncio
import logging
import mimetypes
import socket
from inspect import signature
from pathlib import Path
from typing import Literal

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -14,24 +17,18 @@
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from invokeai.app.services.graph import update_invocations_union

from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger

from invokeai.version.invokeai_version import __version__

# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
import mimetypes

from invokeai.version.invokeai_version import __version__
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase

import torch

# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, BaseInvocationOutput, UIConfigBase
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger

if torch.backends.mps.is_available():
# noinspection PyUnresolvedReferences
Expand Down Expand Up @@ -104,8 +101,8 @@ async def shutdown_event():
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
# if app.openapi_schema:
# return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
Expand Down Expand Up @@ -140,6 +137,9 @@ def custom_openapi():
invoker_name = invoker.__name__
output_type = signature(invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
if invoker_name not in openapi_schema["components"]["schemas"]:
openapi_schema["components"]["schemas"][invoker_name] = invoker.schema()

invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
Expand Down Expand Up @@ -218,7 +218,9 @@ def find_port(port: int):
exc_info=e,
)
else:
jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
from invokeai.app.util.dev_reload import start_reloader

start_reloader()

port = find_port(app_config.port)
if port != app_config.port:
Expand All @@ -242,6 +244,26 @@ def find_port(port: int):
for ch in logger.handlers:
log.addHandler(ch)

class Test1Output(BaseInvocationOutput):
type: Literal["test1_output"] = "test1_output"

class Test1Invocation(BaseInvocation):
type: Literal["test1"] = "test1"

def invoke(self, context) -> Test1Output:
return Test1Output()

class Test2Output(BaseInvocationOutput):
type: Literal["test2_output"] = "test2_output"

class TestInvocation2(BaseInvocation):
type: Literal["test2"] = "test2"

def invoke(self, context) -> Test2Output:
return Test2Output()

update_invocations_union()

loop.run_until_complete(server.serve())


Expand Down
59 changes: 56 additions & 3 deletions invokeai/app/services/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import networkx as nx
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field
from pydantic.fields import Field, ModelField

# Importing * is bad karma but needed here for node detection
from ..invocations import * # noqa: F401 F403
Expand Down Expand Up @@ -232,7 +232,39 @@ def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore


class Graph(BaseModel):
class DynamicBaseModel(BaseModel):
"""https://github.com/pydantic/pydantic/issues/1937#issuecomment-695313040"""

@classmethod
def add_fields(cls, **field_definitions: Any):
new_fields: dict[str, ModelField] = {}
new_annotations: dict[str, Optional[type]] = {}

for f_name, f_def in field_definitions.items():
if isinstance(f_def, tuple):
try:
f_annotation, f_value = f_def
except ValueError as e:
raise Exception(
"field definitions should either be a tuple of (<type>, <default>) or just a "
"default value, unfortunately this means tuples as "
"default values are not allowed"
) from e
else:
f_annotation, f_value = None, f_def

if f_annotation:
new_annotations[f_name] = f_annotation

new_fields[f_name] = ModelField.infer(
name=f_name, value=f_value, annotation=f_annotation, class_validators=None, config=cls.__config__
)

cls.__fields__.update(new_fields)
cls.__annotations__.update(new_annotations)


class Graph(DynamicBaseModel):
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
Expand Down Expand Up @@ -700,7 +732,7 @@ def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[
return g


class GraphExecutionState(BaseModel):
class GraphExecutionState(DynamicBaseModel):
"""Tracks the state of a graph execution"""

id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
Expand Down Expand Up @@ -1131,3 +1163,24 @@ def validate_exposed_nodes(cls, values):


GraphInvocation.update_forward_refs()


def update_invocations_union() -> None:
global InvocationsUnion
global InvocationOutputsUnion
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore

Graph.add_fields(
nodes=(
dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]],
Field(description="The nodes in this graph", default_factory=dict),
)
)

GraphExecutionState.add_fields(
results=(
dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]],
Field(description="The results of node executions", default_factory=dict),
)
)
31 changes: 31 additions & 0 deletions invokeai/app/util/dev_reload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import jurigged
from jurigged.codetools import ClassDefinition

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.backend.util.logging import InvokeAILogger

logger = InvokeAILogger.getLogger(name=__name__)


def reload_nodes(path: str, codefile: jurigged.CodeFile):
"""Callback function for jurigged post-run events."""
# Things we have access to here:
# codefile.module:module - the module object associated with this file
# codefile.module_name:str - the full module name (its key in sys.modules)
# codefile.root:ModuleCode - an AST of the current source
Comment on lines +10 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, I gave you a function that will run after jurigged reloads anything. I haven't yet hooked it up to any of the update functions in this PR.


# This is only reading top-level statements, not walking the whole AST, but class definition should be top-level, right?
class_names = [statement.name for statement in codefile.root.children if isinstance(statement, ClassDefinition)]
classes = [getattr(codefile.module, name) for name in class_names]
invocations = [cls for cls in classes if issubclass(cls, BaseInvocation)]
# outputs = [cls for cls in classes if issubclass(cls, BaseInvocationOutput)]

# We should assume jurigged has already replaced all references to methods of these classes,
# but it hasn't re-executed any annotations on them (like @title or @tags).
# We need to re-do anything that involved introspection like BaseInvocation.get_all_subclasses()
logger.info("File reloaded: %s contains invocation classes %s", path, invocations)


def start_reloader():
watcher = jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
watcher.postrun.register(reload_nodes, apply_history=False)