129 changes: 93 additions & 36 deletions docs/contributing/MODEL_MANAGER.md
@@ -1531,23 +1531,29 @@ Here is a typical initialization pattern:

```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.model_load import ModelLoadService
from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry

config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config)
loader = ModelLoadService(config, store)
ram_cache = ModelCache(
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
loader = ModelLoadService(
app_config=config,
ram_cache=ram_cache,
convert_cache=convert_cache,
registry=ModelLoaderRegistry
)
```

Note that we are relying on the contents of the application
configuration to choose the implementation of
`ModelRecordServiceBase`.
### load_model(model_config, [submodel_type], [context]) -> LoadedModel

### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel

The `load_model_by_key()` method receives the unique key that
identifies the model. It loads the model into memory, gets the model
ready for use, and returns a `LoadedModel` object.
The `load_model()` method takes an `AnyModelConfig` returned by
`ModelRecordService.get_model()` and returns the corresponding loaded
model. It loads the model into memory, gets the model ready for use,
and returns a `LoadedModel` object.
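
A minimal usage sketch (assuming `store` is a `ModelRecordService` and
`loader` is the `ModelLoadService` built above; the model key is
illustrative, and the optional `submodel_type` argument is described
below):

```
from invokeai.backend.model_manager import SubModelType

config = store.get_model('<model key>')   # an AnyModelConfig record
loaded_model = loader.load_model(config, submodel_type=SubModelType('vae'))
with loaded_model as vae:
    # inside the block the model is in memory and ready for use
    ...
```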

The optional second argument, `submodel_type`, is a `SubModelType` string
enum, such as "vae". It is mandatory when used with a main model, and
@@ -1593,25 +1599,6 @@ with model_info as vae:
- `ModelNotFoundException` -- key in database but model not found at path
- `NotImplementedException` -- the loader doesn't know how to load this type of model

### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel

This is similar to `load_model_by_key`, but instead it accepts the
combination of the model's name, type and base, which it passes to the
model record config store for retrieval. If successful, this method
returns a `LoadedModel`. It can raise the following exceptions:

```
UnknownModelException -- model with these attributes not known
NotImplementedException -- the loader doesn't know how to load this type of model
ValueError -- more than one model matches this combination of base/type/name
```

### load_model_by_config(config, [submodel], [context]) -> LoadedModel

This method takes an `AnyModelConfig` returned by
ModelRecordService.get_model() and returns the corresponding loaded
model. It may raise a `NotImplementedException`.

### Emitting model loading events

When the `context` argument is passed to `load_model_*()`, it will
@@ -1656,7 +1643,7 @@ onnx models.

To install a new loader, place it in
`invokeai/backend/model_manager/load/model_loaders`. Inherit from
`ModelLoader` and use the `@AnyModelLoader.register()` decorator to
`ModelLoader` and use the `@ModelLoaderRegistry.register()` decorator to
indicate what type of models the loader can handle.

Here is a complete example from `generic_diffusers.py`, which is able
@@ -1674,12 +1661,11 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from ..load_base import AnyModelLoader
from ..load_default import ModelLoader
from .. import ModelLoader, ModelLoaderRegistry


@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader):
"""Class to load simple diffusers models."""

@@ -1728,3 +1714,74 @@ model. It does whatever it needs to do to get the model into diffusers
format, and returns the Path of the resulting model. (The path should
ordinarily be the same as `output_path`.)
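
As an illustration only, here is a rough sketch of what such an override
might look like, assuming the method receives the source checkpoint path
and a destination `output_path` (the signature and the diffusers-based
conversion shown are assumptions, not the project's actual
implementation):

```
from pathlib import Path

from diffusers import StableDiffusionPipeline


class MyCheckpointLoader(ModelLoader):
    """Sketch of a loader that converts a checkpoint to diffusers format."""

    def _convert_model(self, config, model_path: Path, output_path: Path) -> Path:
        # Hypothetical conversion: read the single-file checkpoint and
        # write it back out as a diffusers folder at output_path.
        pipe = StableDiffusionPipeline.from_single_file(str(model_path))
        pipe.save_pretrained(output_path)
        return output_path
```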

## The ModelManagerService object

For convenience, the API provides a `ModelManagerService` object which
gives a single point of access to the major model manager
services. This object is created at initialization time and can be
found in the global `ApiDependencies.invoker.services.model_manager`
object, or in `context.services.model_manager` from within an
invocation.

In the examples below, we have retrieved the manager using:
```
mm = ApiDependencies.invoker.services.model_manager
```

The following properties and methods will be available:

### mm.store

This retrieves the `ModelRecordService` associated with the
manager. Example:

```
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
```

### mm.install

This retrieves the `ModelInstallService` associated with the manager.
Example:

```
job = mm.install.heuristic_import('https://civitai.com/models/58390/detail-tweaker-lora-lora')
```

### mm.load

This retrieves the `ModelLoadService` associated with the manager. Example:

```
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
assert len(configs) > 0

loaded_model = mm.load.load_model(configs[0])
```

The model manager also offers a few convenience shortcuts for loading
models:

### mm.load_model_by_config(model_config, [submodel], [context]) -> LoadedModel

Same as `mm.load.load_model()`.
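
For example, reusing the config fetched in the `mm.load` example above
(illustrative only):

```
loaded_model = mm.load_model_by_config(configs[0])
```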

### mm.load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel

This accepts the combination of the model's name, type and base, which
it passes to the model record config store for retrieval. If a unique
model config is found, this method returns a `LoadedModel`. It can
raise the following exceptions:

```
UnknownModelException -- model with these attributes not known
NotImplementedException -- the loader doesn't know how to load this type of model
ValueError -- more than one model matches this combination of base/type/name
```
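
A hedged sketch of handling these failure modes (the model attributes are
illustrative, and the import path for `UnknownModelException` is an
assumption):

```
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType

try:
    loaded_model = mm.load_model_by_attr(
        model_name='stable-diffusion-v1-5',
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        submodel=SubModelType('vae'),
    )
except UnknownModelException:
    ...  # no model with these attributes is registered
except ValueError:
    ...  # more than one model matched; disambiguate or use load_model_by_key()
```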

### mm.load_model_by_key(key, [submodel], [context]) -> LoadedModel

This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`.
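
A minimal sketch (the key is illustrative):

```
loaded_model = mm.load_model_by_key('<model key>')
with loaded_model as model:
    ...  # use the in-memory model
```
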
8 changes: 7 additions & 1 deletion invokeai/app/api/dependencies.py
@@ -28,6 +28,8 @@
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -94,8 +96,12 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
model_metadata_service = ModelMetadataStoreSQL(db=db)
model_manager = ModelManagerService.build_model_manager(
app_config=configuration, db=db, download_queue=download_queue_service, events=events
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
download_queue=download_queue_service,
events=events,
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
@@ -35,7 +35,7 @@

from ..dependencies import ApiDependencies

model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])


class ModelsList(BaseModel):
@@ -135,7 +135,7 @@ class ModelTagSet(BaseModel):
##############################################################################


@model_manager_v2_router.get(
@model_manager_router.get(
"/",
operation_id="list_model_records",
)
@@ -164,7 +164,7 @@ async def list_model_records(
return ModelsList(models=found_models)


@model_manager_v2_router.get(
@model_manager_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
@@ -188,7 +188,7 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))


@model_manager_v2_router.get("/summary", operation_id="list_model_summary")
@model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
@@ -200,7 +200,7 @@ async def list_model_summary(
return results


@model_manager_v2_router.get(
@model_manager_router.get(
"/meta/i/{key}",
operation_id="get_model_metadata",
responses={
@@ -223,7 +223,7 @@ async def get_model_metadata(
return result


@model_manager_v2_router.get(
@model_manager_router.get(
"/tags",
operation_id="list_tags",
)
@@ -234,7 +234,7 @@ async def list_tags() -> Set[str]:
return result


@model_manager_v2_router.get(
@model_manager_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
@@ -247,7 +247,7 @@ async def search_by_metadata_tags(
return ModelsList(models=results)


@model_manager_v2_router.patch(
@model_manager_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
@@ -281,7 +281,7 @@ async def update_model_record(
return model_response


@model_manager_v2_router.delete(
@model_manager_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
@@ -311,7 +311,7 @@ async def del_model_record(
raise HTTPException(status_code=404, detail=str(e))


@model_manager_v2_router.post(
@model_manager_router.post(
"/i/",
operation_id="add_model_record",
responses={
@@ -349,7 +349,7 @@ async def add_model_record(
return result


@model_manager_v2_router.post(
@model_manager_router.post(
"/heuristic_import",
operation_id="heuristic_import_model",
responses={
@@ -416,7 +416,7 @@ async def heuristic_import(
return result


@model_manager_v2_router.post(
@model_manager_router.post(
"/install",
operation_id="import_model",
responses={
@@ -516,7 +516,7 @@ async def import_model(
return result


@model_manager_v2_router.get(
@model_manager_router.get(
"/import",
operation_id="list_model_install_jobs",
)
@@ -544,7 +544,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
return jobs


@model_manager_v2_router.get(
@model_manager_router.get(
"/import/{id}",
operation_id="get_model_install_job",
responses={
@@ -564,7 +564,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
raise HTTPException(status_code=404, detail=str(e))


@model_manager_v2_router.delete(
@model_manager_router.delete(
"/import/{id}",
operation_id="cancel_model_install_job",
responses={
@@ -583,7 +583,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)


@model_manager_v2_router.patch(
@model_manager_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
@@ -597,7 +597,7 @@ async def prune_model_install_jobs() -> Response:
return Response(status_code=204)


@model_manager_v2_router.patch(
@model_manager_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
@@ -616,7 +616,7 @@ async def sync_models_to_config() -> Response:
return Response(status_code=204)


@model_manager_v2_router.put(
@model_manager_router.put(
"/convert/{key}",
operation_id="convert_model",
responses={
@@ -694,7 +694,7 @@ async def convert_model(
return new_config


@model_manager_v2_router.put(
@model_manager_router.put(
"/merge",
operation_id="merge",
responses={