Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 35 additions & 33 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,47 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import asyncio
import logging
import socket
from inspect import signature
from pathlib import Path

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema

from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger

from invokeai.version.invokeai_version import __version__

import invokeai.frontend.web as web_dir
import mimetypes
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()

from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio
import logging
import mimetypes
import socket
from inspect import signature
from pathlib import Path

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema

import torch
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.version.invokeai_version import __version__

# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import app_info, board_images, boards, images, models, sessions
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField

if torch.backends.mps.is_available():
# noinspection PyUnresolvedReferences
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available():
# noinspection PyUnresolvedReferences
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)


app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)

# fix for windows mimetypes registry entries being borked
Expand Down
113 changes: 55 additions & 58 deletions invokeai/app/cli_app.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,64 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import argparse
import re
import shlex
import sys
import time
from typing import Union, get_type_hints, Optional

from pydantic import BaseModel, ValidationError
from pydantic.fields import Field

# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__


from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
BoardImagesService,
BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.app.services.invocation_stats import InvocationStatsService
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.graph import (
Edge,
EdgeConnection,
GraphExecutionState,
GraphInvocation,
LibraryGraph,
are_connection_types_compatible,
)
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage

import torch
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)

if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)


# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
config = InvokeAIAppConfig.get_config()
config.parse_args()

if True: # hack to make flake8 happy with imports coming after setting up the config
import argparse
import re
import shlex
import sys
import time
from typing import Optional, Union, get_type_hints

import torch
from pydantic import BaseModel, ValidationError
from pydantic.fields import Field

import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__

from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
from .services.events import EventServiceBase
from .services.graph import (
Edge,
EdgeConnection,
GraphExecutionState,
GraphInvocation,
LibraryGraph,
are_connection_types_compatible,
)
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage

if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)


logger = InvokeAILogger().getLogger(config=config)


Expand Down
21 changes: 20 additions & 1 deletion invokeai/app/invocations/baseinvocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from pydantic.typing import NoArgAnyCallable
import semver

from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig

if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices

Expand Down Expand Up @@ -470,14 +472,31 @@ class BaseInvocation(ABC, BaseModel):

@classmethod
def get_all_subclasses(cls):
app_config = InvokeAIAppConfig.get_config()
subclasses = []
toprocess = [cls]
while len(toprocess) > 0:
next = toprocess.pop(0)
next_subclasses = next.__subclasses__()
subclasses.extend(next_subclasses)
toprocess.extend(next_subclasses)
return subclasses
allowed_invocations = []
for sc in subclasses:
is_in_allowlist = (
sc.__fields__.get("type").default in app_config.allow_nodes
if isinstance(app_config.allow_nodes, list)
else True
)

is_in_denylist = (
sc.__fields__.get("type").default in app_config.deny_nodes
if isinstance(app_config.deny_nodes, list)
else False
)

if is_in_allowlist and not is_in_denylist:
allowed_invocations.append(sc)
return allowed_invocations

@classmethod
def get_invocations(cls):
Expand Down
4 changes: 3 additions & 1 deletion invokeai/app/services/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class InvokeAISettings(BaseSettings):

def parse_args(self, argv: list = sys.argv[1:]):
parser = self.get_parser()
opt = parser.parse_args(argv)
opt, unknown_opts = parser.parse_known_args(argv)
if len(unknown_opts) > 0:
print("Unknown args:", unknown_opts)
for name in self.__fields__:
if name not in self._excluded():
value = getattr(opt, name)
Expand Down
4 changes: 4 additions & 0 deletions invokeai/app/services/config/invokeai_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)

# NODES
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")

# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ import { startAppListening } from '..';
export const addReceivedOpenAPISchemaListener = () => {
startAppListening({
actionCreator: receivedOpenAPISchema.fulfilled,
effect: (action, { dispatch }) => {
effect: (action, { dispatch, getState }) => {
const log = logger('system');
const schemaJSON = action.payload;

log.debug({ schemaJSON }, 'Received OpenAPI schema');

const nodeTemplates = parseSchema(schemaJSON);
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(
schemaJSON,
nodesAllowlist,
nodesDenylist
);

log.debug(
{ nodeTemplates: parseify(nodeTemplates) },
Expand Down
2 changes: 2 additions & 0 deletions invokeai/frontend/web/src/app/types/invokeai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ export type AppConfig = {
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
canRestoreDeletedImagesFromBin: boolean;
nodesAllowlist: string[] | undefined;
nodesDenylist: string[] | undefined;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];
Expand Down
16 changes: 14 additions & 2 deletions invokeai/frontend/web/src/features/nodes/util/parseSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,23 @@ const isNotInDenylist = (schema: InvocationSchemaObject) =>
!invocationDenylist.includes(schema.properties.type.default);

export const parseSchema = (
openAPI: OpenAPIV3.Document
openAPI: OpenAPIV3.Document,
nodesAllowlistExtra: string[] | undefined = undefined,
nodesDenylistExtra: string[] | undefined = undefined
): Record<string, InvocationTemplate> => {
const filteredSchemas = Object.values(openAPI.components?.schemas ?? {})
.filter(isInvocationSchemaObject)
.filter(isNotInDenylist);
.filter(isNotInDenylist)
.filter((schema) =>
nodesAllowlistExtra
? nodesAllowlistExtra.includes(schema.properties.type.default)
: true
)
.filter((schema) =>
nodesDenylistExtra
? !nodesDenylistExtra.includes(schema.properties.type.default)
: true
);

const invocations = filteredSchemas.reduce<
Record<string, InvocationTemplate>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export const initialConfigState: AppConfig = {
'perlinNoise',
'noiseThreshold',
],
nodesAllowlist: undefined,
nodesDenylist: undefined,
canRestoreDeletedImagesFromBin: true,
sd: {
disabledControlNetModels: [],
Expand Down
16 changes: 11 additions & 5 deletions invokeai/frontend/web/src/features/system/store/systemSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { InvokeLogLevel } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions';
import { t } from 'i18next';
import { get, startCase, upperFirst } from 'lodash-es';
import { get, startCase, truncate, upperFirst } from 'lodash-es';
import { LogLevelName } from 'roarr';
import {
isAnySessionRejected,
Expand Down Expand Up @@ -357,10 +357,13 @@ export const systemSlice = createSlice({
result.data.error.detail.map((e) => {
state.toastQueue.push(
makeToast({
title: upperFirst(e.msg),
title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error',
description: `Path:
${e.loc.slice(3).join('.')}`,
description: truncate(
`Path:
${e.loc.join('.')}`,
{ length: 128 }
),
duration,
})
);
Expand All @@ -375,7 +378,10 @@ export const systemSlice = createSlice({
makeToast({
title: t('toast.serverError'),
status: 'error',
description: get(errorDescription, 'detail', 'Unknown Error'),
description: truncate(
get(errorDescription, 'detail', 'Unknown Error'),
{ length: 128 }
),
duration,
})
);
Expand Down
Loading