Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion invokeai/app/api/routers/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def upload_image(
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type.startswith("image"):
if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")

contents = await file.read()
Expand Down
109 changes: 74 additions & 35 deletions invokeai/app/api/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@


import pathlib
from typing import List, Literal, Optional, Union
from typing import Annotated, List, Literal, Optional, Union

from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, parse_obj_as
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from starlette.exceptions import HTTPException

from invokeai.backend import BaseModelType, ModelType
Expand All @@ -23,15 +23,26 @@
models_router = APIRouter(prefix="/v1/models", tags=["models"])

UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
update_models_response_adapter = TypeAdapter(UpdateModelResponse)

ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
import_models_response_adapter = TypeAdapter(ImportModelResponse)

ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
convert_models_response_adapter = TypeAdapter(ConvertModelResponse)

MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]


class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]

model_config = ConfigDict(use_enum_values=True)


models_list_adapter = TypeAdapter(ModelsList)


@models_router.get(
"/",
Expand All @@ -49,7 +60,7 @@ async def list_models(
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
models = parse_obj_as(ModelsList, {"models": models_raw})
models = models_list_adapter.validate_python({"models": models_raw})
return models


Expand Down Expand Up @@ -105,19 +116,22 @@ async def update_model(
info.path = new_info.get("path")

# replace empty string values with None/null to avoid phenomenon of vae: ''
info_dict = info.dict()
info_dict = info.model_dump()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}

ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info_dict,
)

model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = parse_obj_as(UpdateModelResponse, model_raw)
model_response = update_models_response_adapter.validate_python(model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
Expand Down Expand Up @@ -159,7 +173,8 @@ async def import_model(

try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
items_to_import=items_to_import,
prediction_type_helper=lambda x: prediction_types.get(prediction_type),
)
info = installed_models.get(location)

Expand All @@ -171,7 +186,7 @@ async def import_model(
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name, base_model=info.base_model, model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
return import_models_response_adapter.validate_python(model_raw)

except ModelNotFoundException as e:
logger.error(str(e))
Expand Down Expand Up @@ -205,13 +220,18 @@ async def add_model(

try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
info.model_name,
info.base_model,
info.model_type,
model_attributes=info.model_dump(),
)
logger.info(f"Successfully added {info.model_name}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
model_name=info.model_name,
base_model=info.base_model,
model_type=info.model_type,
)
return parse_obj_as(ImportModelResponse, model_raw)
return import_models_response_adapter.validate_python(model_raw)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
Expand All @@ -223,7 +243,10 @@ async def add_model(
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
operation_id="del_model",
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
response_model=None,
)
Expand Down Expand Up @@ -279,7 +302,7 @@ async def convert_model(
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name, base_model=base_model, model_type=model_type
)
response = parse_obj_as(ConvertModelResponse, model_raw)
response = convert_models_response_adapter.validate_python(model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e:
Expand All @@ -302,7 +325,8 @@ async def search_for_models(
) -> List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
status_code=404,
detail=f"The search path '{search_path}' does not exist or is not directory",
)
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)

Expand Down Expand Up @@ -337,6 +361,26 @@ async def sync_to_config() -> bool:
return True


# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
# TODO: After a few updates, see if it works inside the route operation handler?
class MergeModelsBody(BaseModel):
model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
merged_model_name: Optional[str] = Field(description="Name of destination model")
alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
force: Optional[bool] = Field(
description="Force merging of models created with different versions of diffusers",
default=False,
)

merge_dest_directory: Optional[str] = Field(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
)

model_config = ConfigDict(protected_namespaces=())


@models_router.put(
"/merge/{base_model}",
operation_id="merge_models",
Expand All @@ -349,41 +393,36 @@ async def sync_to_config() -> bool:
response_model=MergeModelResponse,
)
async def merge_models(
body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model"),
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(
description="Force merging of models created with different versions of diffusers", default=False
),
merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
logger.info(
f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
)
dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(
model_names,
base_model,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
model_names=body.model_names,
base_model=base_model,
merged_model_name=body.merged_model_name or "+".join(body.model_names),
alpha=body.alpha,
interp=body.interp,
force=body.force,
merge_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
result.name,
base_model=base_model,
model_type=ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
response = convert_models_response_adapter.validate_python(model_raw)
except ModelNotFoundException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
raise HTTPException(
status_code=404,
detail=f"One or more of the models '{body.model_names}' not found",
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
3 changes: 2 additions & 1 deletion invokeai/app/api/routers/utilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from fastapi import Body
Expand Down Expand Up @@ -27,6 +27,7 @@ async def parse_dynamicprompts(
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
) -> DynamicPromptsResponse:
"""Creates a batch process"""
generator: Union[RandomPromptGenerator, CombinatorialPromptGenerator]
try:
error: Optional[str] = None
if combinatorial:
Expand Down
53 changes: 26 additions & 27 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from pydantic.json_schema import models_json_schema

# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
Expand All @@ -31,7 +31,7 @@

from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
from .api.routers import app_info, board_images, boards, images, models, session_queue, utilities
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField

Expand All @@ -51,7 +51,7 @@

# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None, separate_input_output_schemas=False)

# Add event handler
event_handler_id: int = id(app)
Expand All @@ -63,18 +63,18 @@

socket_io = SocketIO(app)

app.add_middleware(
CORSMiddleware,
allow_origins=app_config.allow_origins,
allow_credentials=app_config.allow_credentials,
allow_methods=app_config.allow_methods,
allow_headers=app_config.allow_headers,
)


# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
app.add_middleware(
CORSMiddleware,
allow_origins=app_config.allow_origins,
allow_credentials=app_config.allow_credentials,
allow_methods=app_config.allow_methods,
allow_headers=app_config.allow_headers,
)

ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)


Expand All @@ -85,12 +85,7 @@ async def shutdown_event():


# Include all routers
# TODO: REMOVE
# app.include_router(
# invocation.invocation_router,
# prefix = '/api')

app.include_router(sessions.session_router, prefix="/api")
# app.include_router(sessions.session_router, prefix="/api")

app.include_router(utilities.utilities_router, prefix="/api")

Expand All @@ -117,6 +112,7 @@ def custom_openapi():
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)

# Add all outputs
Expand All @@ -127,29 +123,32 @@ def custom_openapi():
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)

output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
for schema_key, output_schema in output_schemas["definitions"].items():
output_schema["class"] = "output"
openapi_schema["components"]["schemas"][schema_key] = output_schema

output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]

# Add Node Editor UI helper schemas
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
ui_config_schemas = models_json_schema(
[(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
ref_template="#/components/schemas/{model}",
)
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema

# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__
output_type = signature(invoker.invoke).return_annotation
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"

from invokeai.backend.model_management.models import get_model_config_enums

Expand All @@ -172,7 +171,7 @@ def custom_openapi():
return app.openapi_schema


app.openapi = custom_openapi
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment

# Override API doc favicons
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
Expand Down
Loading