From 9149fee12d3e0d711e8ae21ecb4727ca626e4b77 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Thu, 15 Feb 2024 22:41:29 -0500
Subject: [PATCH 1/3] Fix issues identified during PR review by RyanjDick and brandonrising

- ModelMetadataStoreService is now injected into ModelRecordStoreService
  (these two services are really joined at the hip, and should someday be merged)
- ModelRecordStoreService is now injected into ModelManagerService
- Reduced timeout value for the various installer and download wait*() methods
- Introduced a Mock model manager for testing
- Replaced bare print() statements with _logger calls in the install helper backend
- Removed unused code from model loader init file
- Made `locker` a private variable in the `LoadedModel` object.
- Fixed up model merge frontend (will be deprecated anyway!)
---
 invokeai/app/api/dependencies.py | 8 +-
 .../app/services/download/download_default.py | 2 +-
 .../invocation_stats_default.py | 2 -
 .../model_install/model_install_base.py | 6 +-
 .../model_install/model_install_default.py | 15 +-
 .../model_manager/model_manager_default.py | 15 +-
 .../app/services/model_metadata/__init__.py | 9 +
 .../model_metadata/metadata_store_base.py | 65 +++++
 .../model_metadata/metadata_store_sql.py | 222 ++++++++++++++++++
 .../model_records/model_records_base.py | 6 +-
 .../model_records/model_records_sql.py | 18 +-
 invokeai/backend/install/install_helper.py | 13 +-
 .../backend/model_manager/load/__init__.py | 17 --
 .../backend/model_manager/load/load_base.py | 8 +-
 .../model_manager/load/load_default.py | 2 +-
 invokeai/backend/model_manager/merge.py | 5 +-
 .../model_manager/metadata/__init__.py | 5 +-
 invokeai/frontend/merge/merge_diffusers.py | 133 +++++++----
 tests/aa_nodes/test_invoker.py | 3 +-
 .../model_records/test_model_records_sql.py | 3 +-
 .../model_manager_2_fixtures.py | 15 +-
 .../model_metadata/test_model_metadata.py | 8 +-
 22 files changed, 449 insertions(+), 131 deletions(-)
 create mode 100644 invokeai/app/services/model_metadata/__init__.py
 create mode 100644 invokeai/app/services/model_metadata/metadata_store_base.py
 create mode 100644 invokeai/app/services/model_metadata/metadata_store_sql.py

diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 378961a0557..8e79b26e2d9 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -28,6 +28,8 @@
 from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
 from ..services.invoker import Invoker
 from ..services.model_manager.model_manager_default import ModelManagerService
+from ..services.model_metadata import ModelMetadataStoreSQL
+from ..services.model_records import ModelRecordServiceSQL
 from ..services.names.names_default import SimpleNameService
 from ..services.session_processor.session_processor_default import DefaultSessionProcessor
 from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -94,8 +96,12 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
             ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
         )
         download_queue_service = DownloadQueueService(event_bus=events)
+        model_metadata_service = ModelMetadataStoreSQL(db=db)
         model_manager = ModelManagerService.build_model_manager(
-            app_config=configuration, db=db, download_queue=download_queue_service, events=events
+            app_config=configuration,
+            model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
+
download_queue=download_queue_service, + events=events, ) names = SimpleNameService() performance_statistics = InvocationStatsService() diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 6d5cedbcad8..50cac80d094 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -194,7 +194,7 @@ def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: """Block until the indicated job has reached terminal state, or when timeout limit reached.""" start = time.time() while not job.in_terminal_state: - if self._job_completed_event.wait(timeout=5): # in case we miss an event + if self._job_completed_event.wait(timeout=0.25): # in case we miss an event self._job_completed_event.clear() if timeout > 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 6c893021de4..486a1ca5b3e 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -46,8 +46,6 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st # This is to handle case of the model manager not being initialized, which happens # during some tests. services = self._invoker.services - if services.model_manager is None or services.model_manager.load is None: - yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. self._stats[graph_execution_state_id] = GraphExecutionStats() diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 39ea8c4a0d1..2f03db0af72 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -18,7 +18,9 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..model_metadata import ModelMetadataStoreBase class InstallStatus(str, Enum): @@ -243,7 +245,7 @@ def __init__( app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: ModelMetadataStore, + metadata_store: ModelMetadataStoreBase, event_bus: Optional["EventServiceBase"] = None, ): """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 20a85a82a14..7dee8bfd8cb 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -20,7 +20,7 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker -from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL +from invokeai.app.services.model_records import DuplicateModelException, 
ModelRecordServiceBase from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, @@ -33,7 +33,6 @@ AnyModelRepoMetadata, CivitaiMetadataFetch, HuggingFaceMetadataFetch, - ModelMetadataStore, ModelMetadataWithFiles, RemoteModelFile, ) @@ -65,7 +64,6 @@ def __init__( app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: Optional[ModelMetadataStore] = None, event_bus: Optional[EventServiceBase] = None, session: Optional[Session] = None, ): @@ -93,14 +91,7 @@ def __init__( self._running = False self._session = session self._next_job_id = 0 - # There may not necessarily be a metadata store initialized - # so we create one and initialize it with the same sql database - # used by the record store service. - if metadata_store: - self._metadata_store = metadata_store - else: - assert isinstance(record_store, ModelRecordServiceSQL) - self._metadata_store = ModelMetadataStore(record_store.db) + self._metadata_store = record_store.metadata_store # for convenience @property def app_config(self) -> InvokeAIAppConfig: # noqa D102 @@ -259,7 +250,7 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa """Block until all installation jobs are done.""" start = time.time() while len(self._download_cache) > 0: - if self._downloads_changed_event.wait(timeout=5): # in case we miss an event + if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event self._downloads_changed_event.clear() if timeout > 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 028d4af6159..b96341be69e 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -5,7 +5,6 @@ from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache -from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -13,8 +12,7 @@ from ..events.events_base import EventServiceBase from ..model_install import ModelInstallService, ModelInstallServiceBase from ..model_load import ModelLoadService, ModelLoadServiceBase -from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL -from ..shared.sqlite.sqlite_database import SqliteDatabase +from ..model_records import ModelRecordServiceBase from .model_manager_base import ModelManagerServiceBase @@ -64,7 +62,7 @@ def stop(self, invoker: Invoker) -> None: def build_model_manager( cls, app_config: InvokeAIAppConfig, - db: SqliteDatabase, + model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, ) -> Self: @@ -82,19 +80,16 @@ def build_model_manager( convert_cache = ModelConvertCache( cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size ) - record_store = ModelRecordServiceSQL(db=db) loader = ModelLoadService( app_config=app_config, - record_store=record_store, + record_store=model_record_service, ram_cache=ram_cache, convert_cache=convert_cache, ) - record_store._loader = loader # yeah, there is a circular reference here installer = ModelInstallService( app_config=app_config, - record_store=record_store, + record_store=model_record_service, 
download_queue=download_queue, - metadata_store=ModelMetadataStore(db=db), event_bus=events, ) - return cls(store=record_store, install=installer, load=loader) + return cls(store=model_record_service, install=installer, load=loader) diff --git a/invokeai/app/services/model_metadata/__init__.py b/invokeai/app/services/model_metadata/__init__.py new file mode 100644 index 00000000000..981c96b709b --- /dev/null +++ b/invokeai/app/services/model_metadata/__init__.py @@ -0,0 +1,9 @@ +"""Init file for ModelMetadataStoreService module.""" + +from .metadata_store_base import ModelMetadataStoreBase +from .metadata_store_sql import ModelMetadataStoreSQL + +__all__ = [ + "ModelMetadataStoreBase", + "ModelMetadataStoreSQL", +] diff --git a/invokeai/app/services/model_metadata/metadata_store_base.py b/invokeai/app/services/model_metadata/metadata_store_base.py new file mode 100644 index 00000000000..e0e4381b099 --- /dev/null +++ b/invokeai/app/services/model_metadata/metadata_store_base.py @@ -0,0 +1,65 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +Storage for Model Metadata +""" + +from abc import ABC, abstractmethod +from typing import List, Set, Tuple + +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + + +class ModelMetadataStoreBase(ABC): + """Store, search and fetch model metadata retrieved from remote repositories.""" + + @abstractmethod + def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: + """ + Add a block of repo metadata to a model record. + + The model record config must already exist in the database with the + same key. Otherwise a FOREIGN KEY constraint exception will be raised. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to store + """ + + @abstractmethod + def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: + """Retrieve the ModelRepoMetadata corresponding to model key.""" + + @abstractmethod + def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata + """Dump out all the metadata.""" + + @abstractmethod + def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: + """ + Update metadata corresponding to the model with the indicated key. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to update + """ + + @abstractmethod + def list_tags(self) -> Set[str]: + """Return all tags in the tags table.""" + + @abstractmethod + def search_by_tag(self, tags: Set[str]) -> Set[str]: + """Return the keys of models containing all of the listed tags.""" + + @abstractmethod + def search_by_author(self, author: str) -> Set[str]: + """Return the keys of models authored by the indicated author.""" + + @abstractmethod + def search_by_name(self, name: str) -> Set[str]: + """ + Return the keys of models with the indicated name. + + Note that this is the name of the model given to it by + the remote source. The user may have changed the local + name. The local name will be located in the model config + record object. + """ diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py new file mode 100644 index 00000000000..afe9d2c8c69 --- /dev/null +++ b/invokeai/app/services/model_metadata/metadata_store_sql.py @@ -0,0 +1,222 @@ +# Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Development Team +""" +SQL Storage for Model Metadata +""" + +import sqlite3 +from typing import List, Optional, Set, Tuple + +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException +from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase + +from .metadata_store_base import ModelMetadataStoreBase + + +class ModelMetadataStoreSQL(ModelMetadataStoreBase): + """Store, search and fetch model metadata retrieved from remote repositories.""" + + def __init__(self, db: SqliteDatabase): + """ + Initialize a new object from preexisting sqlite3 connection and threading lock objects. + + :param conn: sqlite3 connection object + :param lock: threading Lock object + """ + super().__init__() + self._db = db + self._cursor = self._db.conn.cursor() + + def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: + """ + Add a block of repo metadata to a model record. + + The model record config must already exist in the database with the + same key. Otherwise a FOREIGN KEY constraint exception will be raised. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to store + """ + json_serialized = metadata.model_dump_json() + with self._db.lock: + try: + self._cursor.execute( + """--sql + INSERT INTO model_metadata( + id, + metadata + ) + VALUES (?,?); + """, + ( + model_key, + json_serialized, + ), + ) + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table + self._db.conn.rollback() + raise UnknownMetadataException from excp + except sqlite3.Error as excp: + self._db.conn.rollback() + raise excp + + def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: + """Retrieve the ModelRepoMetadata corresponding to model key.""" + with self._db.lock: + self._cursor.execute( + """--sql + SELECT metadata FROM model_metadata + WHERE id=?; + """, + (model_key,), + ) + rows = self._cursor.fetchone() + if not rows: + raise UnknownMetadataException("model metadata not found") + return ModelMetadataFetchBase.from_json(rows[0]) + + def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata + """Dump out all the metadata.""" + with self._db.lock: + self._cursor.execute( + """--sql + SELECT id,metadata FROM model_metadata; + """, + (), + ) + rows = self._cursor.fetchall() + return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows] + + def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: + """ + Update metadata corresponding to the model with the indicated key. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to update + """ + json_serialized = metadata.model_dump_json() # turn it into a json string. + with self._db.lock: + try: + self._cursor.execute( + """--sql + UPDATE model_metadata + SET + metadata=? 
+ WHERE id=?; + """, + (json_serialized, model_key), + ) + if self._cursor.rowcount == 0: + raise UnknownMetadataException("model metadata not found") + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.Error as e: + self._db.conn.rollback() + raise e + + return self.get_metadata(model_key) + + def list_tags(self) -> Set[str]: + """Return all tags in the tags table.""" + self._cursor.execute( + """--sql + select tag_text from tags; + """ + ) + return {x[0] for x in self._cursor.fetchall()} + + def search_by_tag(self, tags: Set[str]) -> Set[str]: + """Return the keys of models containing all of the listed tags.""" + with self._db.lock: + try: + matches: Optional[Set[str]] = None + for tag in tags: + self._cursor.execute( + """--sql + SELECT a.model_id FROM model_tags AS a, + tags AS b + WHERE a.tag_id=b.tag_id + AND b.tag_text=?; + """, + (tag,), + ) + model_keys = {x[0] for x in self._cursor.fetchall()} + if matches is None: + matches = model_keys + matches = matches.intersection(model_keys) + except sqlite3.Error as e: + raise e + return matches if matches else set() + + def search_by_author(self, author: str) -> Set[str]: + """Return the keys of models authored by the indicated author.""" + self._cursor.execute( + """--sql + SELECT id FROM model_metadata + WHERE author=?; + """, + (author,), + ) + return {x[0] for x in self._cursor.fetchall()} + + def search_by_name(self, name: str) -> Set[str]: + """ + Return the keys of models with the indicated name. + + Note that this is the name of the model given to it by + the remote source. The user may have changed the local + name. The local name will be located in the model config + record object. + """ + self._cursor.execute( + """--sql + SELECT id FROM model_metadata + WHERE name=?; + """, + (name,), + ) + return {x[0] for x in self._cursor.fetchall()} + + def _update_tags(self, model_key: str, tags: Set[str]) -> None: + """Update tags for the model referenced by model_key.""" + # remove previous tags from this model + self._cursor.execute( + """--sql + DELETE FROM model_tags + WHERE model_id=?; + """, + (model_key,), + ) + + for tag in tags: + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO tags ( + tag_text + ) + VALUES (?); + """, + (tag,), + ) + self._cursor.execute( + """--sql + SELECT tag_id + FROM tags + WHERE tag_text = ? 
+ LIMIT 1; + """, + (tag,), + ) + tag_id = self._cursor.fetchone()[0] + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO model_tags ( + model_id, + tag_id + ) + VALUES (?,?); + """, + (model_key, tag_id), + ) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index b2eacc524b7..d6014db448a 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -17,7 +17,9 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..model_metadata import ModelMetadataStoreBase class DuplicateModelException(Exception): @@ -109,7 +111,7 @@ def get_model(self, key: str) -> AnyModelConfig: @property @abstractmethod - def metadata_store(self) -> ModelMetadataStore: + def metadata_store(self) -> ModelMetadataStoreBase: """Return a ModelMetadataStore initialized on the same database.""" pass diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 84a14123838..dcd1114655b 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -54,8 +54,9 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException +from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( DuplicateModelException, @@ -69,7 +70,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase): + def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. @@ -78,6 +79,7 @@ def __init__(self, db: SqliteDatabase): super().__init__() self._db = db self._cursor = db.conn.cursor() + self._metadata_store = metadata_store @property def db(self) -> SqliteDatabase: @@ -157,7 +159,7 @@ def del_model(self, key: str) -> None: self._db.conn.rollback() raise e - def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: + def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: """ Update the model, returning the updated version. @@ -307,9 +309,9 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: return results @property - def metadata_store(self) -> ModelMetadataStore: + def metadata_store(self) -> ModelMetadataStoreBase: """Return a ModelMetadataStore initialized on the same database.""" - return ModelMetadataStore(self._db) + return self._metadata_store def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: """ @@ -330,18 +332,18 @@ def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]: :param tags: Set of tags to search for. All tags must be present. 
""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) keys = store.search_by_tag(tags) return [self.get_model(x) for x in keys] def list_tags(self) -> Set[str]: """Return a unique set of all the model tags in the metadata database.""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) return store.list_tags() def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: """List metadata for all models that have it.""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) return store.list_all_metadata() def list_models( diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 9c386c209ce..3623b623a94 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -25,6 +25,7 @@ ModelSource, URLModelSource, ) +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager import ( @@ -45,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService logger = InvokeAILogger.get_logger(config=app_config) image_files = DiskImageFileStorage(f"{app_config.output_path}/images") db = init_db(config=app_config, logger=logger, image_files=image_files) - obj: ModelRecordServiceBase = ModelRecordServiceSQL(db) + obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) return obj @@ -54,12 +55,10 @@ def initialize_installer( ) -> ModelInstallServiceBase: """Return an initialized ModelInstallService object.""" record_store = initialize_record_store(app_config) - metadata_store = record_store.metadata_store download_queue = DownloadQueueService() installer = ModelInstallService( app_config=app_config, record_store=record_store, - metadata_store=metadata_store, download_queue=download_queue, event_bus=event_bus, ) @@ -287,14 +286,14 @@ def add_or_delete(self, selections: InstallSelections) -> None: model_name=model_name, ) if len(matches) > 1: - print( - f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate." + self._logger.error( + "{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. 
sd-1/main/my_model) to disambiguate" ) elif not matches: - print(f"{model_to_remove}: unknown model") + self._logger.error(f"{model_to_remove}: unknown model") else: for m in matches: - print(f"Deleting {m.type}:{m.name}") + self._logger.info(f"Deleting {m.type}:{m.name}") installer.delete(m.key) installer.wait_for_installs() diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index 966a739237a..a3a840b6259 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -4,10 +4,6 @@ """ from importlib import import_module from pathlib import Path -from typing import Optional - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger from .convert_cache.convert_cache_default import ModelConvertCache from .load_base import AnyModelLoader, LoadedModel @@ -19,16 +15,3 @@ import_module(f"{__package__}.model_loaders.{module}") __all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"] - - -def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: - app_config = app_config or InvokeAIAppConfig.get_config() - logger = InvokeAILogger.get_logger(config=app_config) - return AnyModelLoader( - app_config=app_config, - logger=logger, - ram_cache=ModelCache( - logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size - ), - convert_cache=ModelConvertCache(app_config.models_convert_cache_path), - ) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 7649dee762b..4c5e899aa3b 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -39,21 +39,21 @@ class LoadedModel: """Context manager object that mediates transfer from RAM<->VRAM.""" config: AnyModelConfig - locker: ModelLockerBase + _locker: ModelLockerBase def __enter__(self) -> AnyModel: """Context entry.""" - self.locker.lock() + self._locker.lock() return self.model def __exit__(self, *args: Any, **kwargs: Any) -> None: """Context exit.""" - self.locker.unlock() + self._locker.unlock() @property def model(self) -> AnyModel: """Return the model without locking it.""" - return self.locker.model + return self._locker.model class ModelLoaderBase(ABC): diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 1dac121a300..79c9311de1d 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -75,7 +75,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo model_path = self._convert_if_needed(model_config, model_path, submodel_type) locker = self._load_if_needed(model_config, model_path, submodel_type) - return LoadedModel(config=model_config, locker=locker) + return LoadedModel(config=model_config, _locker=locker) def _get_model_path( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 2c94af4af3b..108f1f0e6f7 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -39,10 +39,7 @@ class ModelMerger(object): def __init__(self, installer: ModelInstallServiceBase): """ - Initialize a ModelMerger object. 
- - :param store: Underlying storage manager for the running process. - :param config: InvokeAIAppConfig object (if not provided, default will be selected). + Initialize a ModelMerger object with the model installer. """ self._installer = installer diff --git a/invokeai/backend/model_manager/metadata/__init__.py b/invokeai/backend/model_manager/metadata/__init__.py index 672e378c7fe..a35e55f3d24 100644 --- a/invokeai/backend/model_manager/metadata/__init__.py +++ b/invokeai/backend/model_manager/metadata/__init__.py @@ -18,7 +18,7 @@ if data.allow_commercial_use: print("Commercial use of this model is allowed") """ -from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch +from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase from .metadata_base import ( AnyModelRepoMetadata, AnyModelRepoMetadataValidator, @@ -31,7 +31,6 @@ RemoteModelFile, UnknownMetadataException, ) -from .metadata_store import ModelMetadataStore __all__ = [ "AnyModelRepoMetadata", @@ -42,7 +41,7 @@ "HuggingFaceMetadata", "HuggingFaceMetadataFetch", "LicenseRestrictions", - "ModelMetadataStore", + "ModelMetadataFetchBase", "BaseMetadata", "ModelMetadataWithFiles", "RemoteModelFile", diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 92b98b52f96..5484040674d 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -6,20 +6,40 @@ """ import argparse import curses +import re import sys from argparse import Namespace from pathlib import Path -from typing import List +from typing import List, Optional, Tuple import npyscreen from npyscreen import widget -import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType +from invokeai.app.services.download import DownloadQueueService +from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage +from invokeai.app.services.model_install import ModelInstallService +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL +from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL +from invokeai.app.services.shared.sqlite.sqlite_util import init_db +from invokeai.backend.model_manager import ( + BaseModelType, + ModelFormat, + ModelType, + ModelVariantType, +) +from invokeai.backend.model_manager.merge import ModelMerger +from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox config = InvokeAIAppConfig.get_config() +logger = InvokeAILogger.get_logger() + +BASE_TYPES = [ + (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), + (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), + (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), +] def _parse_args() -> Namespace: @@ -48,7 +68,7 @@ def _parse_args() -> Namespace: parser.add_argument( "--base_model", type=str, - choices=[x.value for x in BaseModelType], + choices=[x[0].value for x in BASE_TYPES], help="The base model shared by the models to be merged", ) parser.add_argument( @@ -98,17 +118,17 @@ def __init__(self, parentApp, name): super().__init__(parentApp, name) @property - def model_manager(self): - return self.parentApp.model_manager + def record_store(self): + return self.parentApp.record_store def afterEditing(self): self.parentApp.setNextForm(None) def 
create(self): window_height, window_width = curses.initscr().getmaxyx() - - self.model_names = self.get_model_names() self.current_base = 0 + self.models = self.get_models(BASE_TYPES[self.current_base][0]) + self.model_names = [x[1] for x in self.models] max_width = max([len(x) for x in self.model_names]) max_width += 6 horizontal_layout = max_width * 3 < window_width @@ -128,11 +148,7 @@ def create(self): self.nextrely += 1 self.base_select = self.add_widget_intelligent( SingleSelectColumns, - values=[ - "Models Built on SD-1.x", - "Models Built on SD-2.x", - "Models Built on SDXL", - ], + values=[x[1] for x in BASE_TYPES], value=[self.current_base], columns=4, max_height=2, @@ -263,21 +279,20 @@ def on_cancel(self): sys.exit(0) def marshall_arguments(self) -> dict: - model_names = self.model_names + model_keys = [x[0] for x in self.models] models = [ - model_names[self.model1.value[0]], - model_names[self.model2.value[0]], + model_keys[self.model1.value[0]], + model_keys[self.model2.value[0]], ] if self.model3.value[0] > 0: - models.append(model_names[self.model3.value[0] - 1]) + models.append(model_keys[self.model3.value[0] - 1]) interp = "add_difference" else: interp = self.interpolations[self.merge_method.value[0]] - bases = ["sd-1", "sd-2", "sdxl"] args = { - "model_names": models, - "base_model": BaseModelType(bases[self.base_select.value[0]]), + "model_keys": models, + "base_model": tuple(BaseModelType)[self.base_select.value[0]], "alpha": self.alpha.value, "interp": interp, "force": self.force.value, @@ -311,18 +326,18 @@ def validate_field_values(self) -> bool: else: return True - def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]: - model_names = [ - info["model_name"] - for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) - if info["model_format"] == "diffusers" + def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name + models = [ + (x.key, x.name) + for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model) + if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal") ] - return sorted(model_names) + return sorted(models, key=lambda x: x[1]) - def _populate_models(self, value=None): - bases = ["sd-1", "sd-2", "sdxl"] - base_model = BaseModelType(bases[value[0]]) - self.model_names = self.get_model_names(base_model) + def _populate_models(self, value: List[int]): + base_model = BASE_TYPES[value[0]][0] + self.models = self.get_models(base_model) + self.model_names = [x[1] for x in self.models] models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") @@ -334,24 +349,24 @@ def _populate_models(self, value=None): class Mergeapp(npyscreen.NPSAppManaged): - def __init__(self, model_manager: ModelManager): + def __init__(self, record_store: ModelRecordServiceBase): super().__init__() - self.model_manager = model_manager + self.record_store = record_store def onStart(self): npyscreen.setTheme(npyscreen.Themes.ElegantTheme) self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") -def run_gui(args: Namespace): - model_manager = ModelManager(config.model_conf_path) - mergeapp = Mergeapp(model_manager) +def run_gui(args: Namespace) -> None: + record_store: ModelRecordServiceBase = get_config_store() + mergeapp = Mergeapp(record_store) mergeapp.run() - args = mergeapp.merge_arguments - merger = ModelMerger(model_manager) + merger = 
get_model_merger(record_store) merger.merge_diffusion_models_and_save(**args) - logger.info(f'Models merged into new model: "{args["merged_model_name"]}".') + merged_model_name = args["merged_model_name"] + logger.info(f'Models merged into new model: "{merged_model_name}".') def run_cli(args: Namespace): @@ -364,20 +379,54 @@ def run_cli(args: Namespace): args.merged_model_name = "+".join(args.model_names) logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') - model_manager = ModelManager(config.model_conf_path) + record_store: ModelRecordServiceBase = get_config_store() assert ( - not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber + len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' - merger = ModelMerger(model_manager) - merger.merge_diffusion_models_and_save(**vars(args)) + merger = get_model_merger(record_store) + model_keys = [] + for name in args.model_names: + if len(name) == 32 and re.match(r"^[0-9a-f]$", name): + model_keys.append(name) + else: + models = record_store.search_by_attr( + model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model) + ) + assert len(models) > 0, f"{name}: Unknown model" + assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead." + model_keys.append(models[0].key) + + merger.merge_diffusion_models_and_save( + alpha=args.alpha, + model_keys=model_keys, + merged_model_name=args.merged_model_name, + interp=args.interp, + force=args.force, + ) logger.info(f'Models merged into new model: "{args.merged_model_name}".') +def get_config_store() -> ModelRecordServiceSQL: + output_path = config.output_path + assert output_path is not None + image_files = DiskImageFileStorage(output_path / "images") + db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files) + return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + + +def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger: + installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService()) + installer.start() + return ModelMerger(installer) + + def main(): args = _parse_args() if args.root_dir: config.parse_args(["--root", str(args.root_dir)]) + else: + config.parse_args([]) try: if args.front_end: diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 774f7501dc2..f67b5a2ac55 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -1,4 +1,5 @@ import logging +from unittest.mock import Mock import pytest @@ -64,7 +65,7 @@ def mock_services() -> InvocationServices: images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), logger=logging, # type: ignore - model_manager=None, # type: ignore + model_manager=Mock(), # type: ignore download_queue=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 46afe0105b5..852e1da979c 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -8,6 +8,7 @@ import pytest from invokeai.app.services.config import InvokeAIAppConfig 
+from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ( DuplicateModelException, ModelRecordOrderBy, @@ -36,7 +37,7 @@ def store( config = InvokeAIAppConfig(root=datadir) logger = InvokeAILogger.get_logger(config=config) db = create_mock_sqlite_database(config, logger) - return ModelRecordServiceSQL(db) + return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) def example_config() -> TextualInversionConfig: diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager_2/model_manager_2_fixtures.py index d85eab67dd3..ebdc9cb5cd6 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager_2/model_manager_2_fixtures.py @@ -14,6 +14,7 @@ from invokeai.app.services.download import DownloadQueueService from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase +from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( BaseModelType, @@ -21,7 +22,6 @@ ModelType, ) from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache -from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager_2.model_metadata.metadata_examples import ( RepoCivitaiModelMetadata1, @@ -104,7 +104,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordS def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: logger = InvokeAILogger.get_logger(config=mm2_app_config) db = create_mock_sqlite_database(mm2_app_config, logger) - store = ModelRecordServiceSQL(db) + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) # add five simple config records to the database raw1 = { "path": "/tmp/foo1", @@ -163,15 +163,14 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL @pytest.fixture -def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore: - db = mm2_record_store._db # to ensure we are sharing the same database - return ModelMetadataStore(db) +def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: + return mm2_record_store.metadata_store @pytest.fixture def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: """This fixtures defines a series of mock URLs for testing download and installation.""" - sess = TestSession() + sess: Session = TestSession() sess.mount( "https://test.com/missing_model.safetensors", TestAdapter( @@ -258,8 +257,7 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) events = DummyEventService() - store = ModelRecordServiceSQL(db) - metadata_store = ModelMetadataStore(db) + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) download_queue = DownloadQueueService(requests_session=mm2_session) download_queue.start() @@ -268,7 +266,6 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo app_config=mm2_app_config, record_store=store, download_queue=download_queue, - metadata_store=metadata_store, event_bus=events, session=mm2_session, ) diff 
--git a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py b/tests/backend/model_manager_2/model_metadata/test_model_metadata.py
index 5a2ec937673..f61eab1b5d0 100644
--- a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py
+++ b/tests/backend/model_manager_2/model_metadata/test_model_metadata.py
@@ -8,6 +8,7 @@
 from pydantic.networks import HttpUrl
 from requests.sessions import Session

+from invokeai.app.services.model_metadata import ModelMetadataStoreBase
 from invokeai.backend.model_manager.config import ModelRepoVariant
 from invokeai.backend.model_manager.metadata import (
     CivitaiMetadata,
@@ -15,14 +16,13 @@
     CommercialUsage,
     HuggingFaceMetadata,
     HuggingFaceMetadataFetch,
-    ModelMetadataStore,
     UnknownMetadataException,
 )
 from invokeai.backend.model_manager.util import select_hf_files
 from tests.backend.model_manager_2.model_manager_2_fixtures import *  # noqa F403


-def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
+def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
     tags = {"text-to-image", "diffusers"}
     input_metadata = HuggingFaceMetadata(
         name="sdxl-vae",
@@ -40,7 +40,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
     assert mm2_metadata_store.list_tags() == tags


-def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
+def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
     input_metadata = HuggingFaceMetadata(
         name="sdxl-vae",
         author="stabilityai",
@@ -57,7 +57,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
     assert input_metadata == output_metadata


-def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
+def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
     metadata1 = HuggingFaceMetadata(
         name="sdxl-vae",
         author="stabilityai",

From 191271d56295b9676b030e1068358ea85b8033e6 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 17 Feb 2024 11:45:32 -0500
Subject: [PATCH 2/3] Tidy names and locations of modules

- Rename old "model_management" directory to "model_management_OLD" in order to catch dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplify the Union returned by the model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and a finalizer to the queue and installer fixtures that will stop the services and release threads.
--- invokeai/app/invocations/compel.py | 6 +- invokeai/app/invocations/latent.py | 4 +- .../services/model_load/model_load_default.py | 23 +- .../app/services/model_manager/__init__.py | 3 +- invokeai/backend/ip_adapter/ip_adapter.py | 3 +- invokeai/backend/{embeddings => }/lora.py | 7 +- .../README.md | 0 .../__init__.py | 0 .../convert_ckpt_to_diffusers.py | 0 .../detect_baked_in_vae.py | 0 .../lora.py | 0 .../memory_snapshot.py | 0 .../model_cache.py | 0 .../model_load_optimizations.py | 0 .../model_manager.py | 0 .../model_merge.py | 0 .../model_probe.py | 0 .../model_search.py | 0 .../models/__init__.py | 0 .../models/base.py | 0 .../models/clip_vision.py | 0 .../models/controlnet.py | 0 .../models/ip_adapter.py | 0 .../models/lora.py | 0 .../models/sdxl.py | 0 .../models/stable_diffusion.py | 0 .../models/stable_diffusion_onnx.py | 0 .../models/t2i_adapter.py | 0 .../models/textual_inversion.py | 0 .../models/vae.py | 0 .../seamless.py | 0 .../util.py | 0 invokeai/backend/model_manager/config.py | 9 +- .../libc_util.py | 0 .../model_manager/load/memory_snapshot.py | 4 +- .../model_manager/load/model_loaders/lora.py | 2 +- .../load/model_loaders/textual_inversion.py | 2 +- invokeai/backend/model_manager/probe.py | 9 +- .../backend/model_manager/util/libc_util.py | 75 ++ .../backend/model_manager/util/model_util.py | 129 +++ .../backend/{embeddings => }/model_patcher.py | 0 invokeai/backend/onnx/onnx_runtime.py | 1 + invokeai/backend/raw_model.py | 14 + .../{embeddings => }/textual_inversion.py | 6 +- invokeai/backend/util/test_utils.py | 45 +- invokeai/configs/INITIAL_MODELS.yaml.OLD | 153 ---- invokeai/configs/models.yaml.example | 47 - .../frontend/install/model_install.py.OLD | 845 ------------------ .../frontend/merge/merge_diffusers.py.OLD | 438 --------- .../model_install/test_model_install.py | 2 +- .../model_records/test_model_records_sql.py | 2 +- tests/backend/ip_adapter/test_ip_adapter.py | 2 +- .../data/invokeai_root/README | 0 .../stable-diffusion/v1-inference.yaml | 0 .../data/invokeai_root/databases/README | 0 .../data/invokeai_root/models/README | 0 .../test-diffusers-main/model_index.json | 0 .../scheduler/scheduler_config.json | 0 .../text_encoder/config.json | 0 .../text_encoder/model.fp16.safetensors | 0 .../text_encoder/model.safetensors | 0 .../text_encoder_2/config.json | 0 .../text_encoder_2/model.fp16.safetensors | 0 .../text_encoder_2/model.safetensors | 0 .../test-diffusers-main/tokenizer/merges.txt | 0 .../tokenizer/special_tokens_map.json | 0 .../tokenizer/tokenizer_config.json | 0 .../test-diffusers-main/tokenizer/vocab.json | 0 .../tokenizer_2/merges.txt | 0 .../tokenizer_2/special_tokens_map.json | 0 .../tokenizer_2/tokenizer_config.json | 0 .../tokenizer_2/vocab.json | 0 .../test-diffusers-main/unet/config.json | 0 .../diffusion_pytorch_model.fp16.safetensors | 0 .../unet/diffusion_pytorch_model.safetensors | 0 .../test-diffusers-main/vae/config.json | 0 .../diffusion_pytorch_model.fp16.safetensors | 0 .../vae/diffusion_pytorch_model.safetensors | 0 .../test_files/test_embedding.safetensors | Bin .../model_loading/test_model_load.py | 11 +- .../model_manager_fixtures.py} | 101 ++- .../model_metadata/metadata_examples.py | 0 .../model_metadata/test_model_metadata.py | 2 +- .../test_libc_util.py | 2 +- .../test_lora.py | 4 +- .../test_memory_snapshot.py | 6 +- .../test_model_load_optimization.py | 2 +- .../util/test_hf_model_select.py | 0 tests/conftest.py | 5 - 89 files changed, 355 insertions(+), 1609 deletions(-) rename invokeai/backend/{embeddings => 
}/lora.py (99%) rename invokeai/backend/{model_management => model_management_OLD}/README.md (100%) rename invokeai/backend/{model_management => model_management_OLD}/__init__.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/convert_ckpt_to_diffusers.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/detect_baked_in_vae.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/lora.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/memory_snapshot.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_cache.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_load_optimizations.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_manager.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_merge.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_probe.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_search.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/__init__.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/base.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/clip_vision.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/controlnet.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/ip_adapter.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/lora.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/sdxl.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/stable_diffusion.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/stable_diffusion_onnx.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/t2i_adapter.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/textual_inversion.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/vae.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/seamless.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/util.py (100%) rename invokeai/backend/{model_management => model_manager}/libc_util.py (100%) create mode 100644 invokeai/backend/model_manager/util/libc_util.py create mode 100644 invokeai/backend/model_manager/util/model_util.py rename invokeai/backend/{embeddings => }/model_patcher.py (100%) create mode 100644 invokeai/backend/raw_model.py rename invokeai/backend/{embeddings => }/textual_inversion.py (97%) delete mode 100644 invokeai/configs/INITIAL_MODELS.yaml.OLD delete mode 100644 invokeai/configs/models.yaml.example delete mode 100644 invokeai/frontend/install/model_install.py.OLD delete mode 100644 invokeai/frontend/merge/merge_diffusers.py.OLD rename tests/backend/{model_manager_2 => model_manager}/data/invokeai_root/README (100%) rename tests/backend/{model_manager_2 => model_manager}/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml (100%) rename tests/backend/{model_manager_2 => model_manager}/data/invokeai_root/databases/README (100%) rename tests/backend/{model_manager_2 => model_manager}/data/invokeai_root/models/README (100%) rename tests/backend/{model_manager_2 => 
model_manager}/data/test_files/test-diffusers-main/model_index.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/scheduler/scheduler_config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder/config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder/model.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder_2/config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer/merges.txt (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer/vocab.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer_2/merges.txt (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer_2/vocab.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/unet/config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/vae/config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test_embedding.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/model_loading/test_model_load.py (61%) rename tests/backend/{model_manager_2/model_manager_2_fixtures.py => model_manager/model_manager_fixtures.py} (80%) rename tests/backend/{model_manager_2 => model_manager}/model_metadata/metadata_examples.py (100%) rename tests/backend/{model_manager_2 => model_manager}/model_metadata/test_model_metadata.py (99%) rename tests/backend/{model_management => model_manager}/test_libc_util.py (88%) rename tests/backend/{model_management => model_manager}/test_lora.py (96%) rename tests/backend/{model_management => 
model_manager}/test_memory_snapshot.py (87%) rename tests/backend/{model_management => model_manager}/test_model_load_optimization.py (96%) rename tests/backend/{model_manager_2 => model_manager}/util/test_hf_model_select.py (100%) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 5159d5b89c5..593121ba60b 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -17,9 +17,9 @@ from invokeai.app.services.model_records import UnknownModelException from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt -from invokeai.backend.embeddings.lora import LoRAModelRaw -from invokeai.backend.embeddings.model_patcher import ModelPatcher -from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.lora import LoRAModelRaw +from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ModelType from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 1f21b539dc9..bfe7255b628 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -50,10 +50,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image -from invokeai.backend.embeddings.lora import LoRAModelRaw -from invokeai.backend.embeddings.model_patcher import ModelPatcher from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus +from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType, LoadedModel +from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo from invokeai.backend.util.silence_warnings import SilenceWarnings diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 29b297c8145..fa96a4672d1 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -21,11 +21,11 @@ class ModelLoadService(ModelLoadServiceBase): """Wrapper around AnyModelLoader.""" def __init__( - self, - app_config: InvokeAIAppConfig, - record_store: ModelRecordServiceBase, - ram_cache: Optional[ModelCacheBase[AnyModel]] = None, - convert_cache: Optional[ModelConvertCacheBase] = None, + self, + app_config: InvokeAIAppConfig, + record_store: ModelRecordServiceBase, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, ): """Initialize the model load service.""" logger = InvokeAILogger.get_logger(self.__class__.__name__) @@ -34,17 +34,8 @@ def __init__( self._any_loader = AnyModelLoader( app_config=app_config, logger=logger, - ram_cache=ram_cache - or ModelCache( - max_cache_size=app_config.ram_cache_size, - max_vram_cache_size=app_config.vram_cache_size, - logger=logger, - ), - convert_cache=convert_cache - or ModelConvertCache( - cache_path=app_config.models_convert_cache_path, - 
max_size=app_config.convert_cache_size, - ), + ram_cache=ram_cache, + convert_cache=convert_cache, ) def start(self, invoker: Invoker) -> None: diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index 5e281922a8b..66707493f71 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -3,9 +3,10 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import LoadedModel -from .model_manager_default import ModelManagerService +from .model_manager_default import ModelManagerServiceBase, ModelManagerService __all__ = [ + "ModelManagerServiceBase", "ModelManagerService", "AnyModel", "AnyModelConfig", diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index b4706ea99c0..3ba6fc5a23c 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -10,6 +10,7 @@ from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights from .resampler import Resampler +from ..raw_model import RawModel class ImageProjModel(torch.nn.Module): @@ -91,7 +92,7 @@ def forward(self, image_embeds): return clip_extra_context_tokens -class IPAdapter: +class IPAdapter(RawModel): """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf""" def __init__( diff --git a/invokeai/backend/embeddings/lora.py b/invokeai/backend/lora.py similarity index 99% rename from invokeai/backend/embeddings/lora.py rename to invokeai/backend/lora.py index 3c7ef074efe..fb0c23067fb 100644 --- a/invokeai/backend/embeddings/lora.py +++ b/invokeai/backend/lora.py @@ -10,8 +10,7 @@ from typing_extensions import Self from invokeai.backend.model_manager import BaseModelType - -from .embedding_base import EmbeddingModelRaw +from .raw_model import RawModel class LoRALayerBase: @@ -367,9 +366,7 @@ def to( AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] - -# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix -class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module): +class LoRAModelRaw(RawModel): # (torch.nn.Module): _name: str layers: Dict[str, AnyLoRALayer] diff --git a/invokeai/backend/model_management/README.md b/invokeai/backend/model_management_OLD/README.md similarity index 100% rename from invokeai/backend/model_management/README.md rename to invokeai/backend/model_management_OLD/README.md diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management_OLD/__init__.py similarity index 100% rename from invokeai/backend/model_management/__init__.py rename to invokeai/backend/model_management_OLD/__init__.py diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py similarity index 100% rename from invokeai/backend/model_management/convert_ckpt_to_diffusers.py rename to invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py diff --git a/invokeai/backend/model_management/detect_baked_in_vae.py b/invokeai/backend/model_management_OLD/detect_baked_in_vae.py similarity index 100% rename from invokeai/backend/model_management/detect_baked_in_vae.py rename to invokeai/backend/model_management_OLD/detect_baked_in_vae.py diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management_OLD/lora.py similarity index 100% rename 
from invokeai/backend/model_management/lora.py rename to invokeai/backend/model_management_OLD/lora.py diff --git a/invokeai/backend/model_management/memory_snapshot.py b/invokeai/backend/model_management_OLD/memory_snapshot.py similarity index 100% rename from invokeai/backend/model_management/memory_snapshot.py rename to invokeai/backend/model_management_OLD/memory_snapshot.py diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management_OLD/model_cache.py similarity index 100% rename from invokeai/backend/model_management/model_cache.py rename to invokeai/backend/model_management_OLD/model_cache.py diff --git a/invokeai/backend/model_management/model_load_optimizations.py b/invokeai/backend/model_management_OLD/model_load_optimizations.py similarity index 100% rename from invokeai/backend/model_management/model_load_optimizations.py rename to invokeai/backend/model_management_OLD/model_load_optimizations.py diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management_OLD/model_manager.py similarity index 100% rename from invokeai/backend/model_management/model_manager.py rename to invokeai/backend/model_management_OLD/model_manager.py diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management_OLD/model_merge.py similarity index 100% rename from invokeai/backend/model_management/model_merge.py rename to invokeai/backend/model_management_OLD/model_merge.py diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management_OLD/model_probe.py similarity index 100% rename from invokeai/backend/model_management/model_probe.py rename to invokeai/backend/model_management_OLD/model_probe.py diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management_OLD/model_search.py similarity index 100% rename from invokeai/backend/model_management/model_search.py rename to invokeai/backend/model_management_OLD/model_search.py diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management_OLD/models/__init__.py similarity index 100% rename from invokeai/backend/model_management/models/__init__.py rename to invokeai/backend/model_management_OLD/models/__init__.py diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management_OLD/models/base.py similarity index 100% rename from invokeai/backend/model_management/models/base.py rename to invokeai/backend/model_management_OLD/models/base.py diff --git a/invokeai/backend/model_management/models/clip_vision.py b/invokeai/backend/model_management_OLD/models/clip_vision.py similarity index 100% rename from invokeai/backend/model_management/models/clip_vision.py rename to invokeai/backend/model_management_OLD/models/clip_vision.py diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management_OLD/models/controlnet.py similarity index 100% rename from invokeai/backend/model_management/models/controlnet.py rename to invokeai/backend/model_management_OLD/models/controlnet.py diff --git a/invokeai/backend/model_management/models/ip_adapter.py b/invokeai/backend/model_management_OLD/models/ip_adapter.py similarity index 100% rename from invokeai/backend/model_management/models/ip_adapter.py rename to invokeai/backend/model_management_OLD/models/ip_adapter.py diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management_OLD/models/lora.py 
similarity index 100% rename from invokeai/backend/model_management/models/lora.py rename to invokeai/backend/model_management_OLD/models/lora.py diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_management_OLD/models/sdxl.py similarity index 100% rename from invokeai/backend/model_management/models/sdxl.py rename to invokeai/backend/model_management_OLD/models/sdxl.py diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management_OLD/models/stable_diffusion.py similarity index 100% rename from invokeai/backend/model_management/models/stable_diffusion.py rename to invokeai/backend/model_management_OLD/models/stable_diffusion.py diff --git a/invokeai/backend/model_management/models/stable_diffusion_onnx.py b/invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py similarity index 100% rename from invokeai/backend/model_management/models/stable_diffusion_onnx.py rename to invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py diff --git a/invokeai/backend/model_management/models/t2i_adapter.py b/invokeai/backend/model_management_OLD/models/t2i_adapter.py similarity index 100% rename from invokeai/backend/model_management/models/t2i_adapter.py rename to invokeai/backend/model_management_OLD/models/t2i_adapter.py diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management_OLD/models/textual_inversion.py similarity index 100% rename from invokeai/backend/model_management/models/textual_inversion.py rename to invokeai/backend/model_management_OLD/models/textual_inversion.py diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management_OLD/models/vae.py similarity index 100% rename from invokeai/backend/model_management/models/vae.py rename to invokeai/backend/model_management_OLD/models/vae.py diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management_OLD/seamless.py similarity index 100% rename from invokeai/backend/model_management/seamless.py rename to invokeai/backend/model_management_OLD/seamless.py diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management_OLD/util.py similarity index 100% rename from invokeai/backend/model_management/util.py rename to invokeai/backend/model_management_OLD/util.py diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 42921f0b32c..bc4848b0a50 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -28,12 +28,11 @@ from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from typing_extensions import Annotated, Any, Dict -from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from ..raw_model import RawModel -from ..embeddings.embedding_base import EmbeddingModelRaw -from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus - -AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw] +# ModelMixin is the base class for all diffusers and transformers models +# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime +AnyModel = Union[ModelMixin, RawModel, torch.nn.Module] class InvalidModelConfigException(Exception): diff --git a/invokeai/backend/model_management/libc_util.py b/invokeai/backend/model_manager/libc_util.py similarity index 100% rename from 
invokeai/backend/model_management/libc_util.py rename to invokeai/backend/model_manager/libc_util.py diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py index 346f5dc4247..209d7166f36 100644 --- a/invokeai/backend/model_manager/load/memory_snapshot.py +++ b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -5,7 +5,7 @@ import torch from typing_extensions import Self -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 +from ..util.libc_util import LibcUtil, Struct_mallinfo2 GB = 2**30 # 1 GB @@ -97,4 +97,4 @@ def get_msg_line(prefix: str, val1: int, val2: int) -> str: if snapshot_1.vram is not None and snapshot_2.vram is not None: msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) - return "\n" + msg if len(msg) > 0 else msg + return msg diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index d8e5f920e24..6ff2dcc9182 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -7,7 +7,7 @@ from typing import Optional, Tuple from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.embeddings.lora import LoRAModelRaw +from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py index 6635f6b43fe..94767479609 100644 --- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Optional, Tuple -from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 2c2066d7c52..d511ffa875f 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -8,9 +8,7 @@ from picklescan.scanner import scan_file_path import invokeai.backend.util.logging as logger -from invokeai.backend.model_management.models.base import read_checkpoint_meta -from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat -from invokeai.backend.model_management.util import lora_token_vector_length +from .util.model_util import lora_token_vector_length, read_checkpoint_meta from invokeai.backend.util.util import SilenceWarnings from .config import ( @@ -55,7 +53,6 @@ }, } - class ProbeBase(object): """Base class for probes.""" @@ -653,8 +650,8 @@ def get_base_type(self) -> BaseModelType: class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> IPAdapterModelFormat: - return IPAdapterModelFormat.InvokeAI.value + def get_format(self) -> ModelFormat: + return ModelFormat.InvokeAI def get_base_type(self) -> BaseModelType: model_file = self.model_path / "ip_adapter.bin" diff --git a/invokeai/backend/model_manager/util/libc_util.py b/invokeai/backend/model_manager/util/libc_util.py new file mode 100644 index 00000000000..1fbcae0a93c --- /dev/null +++ b/invokeai/backend/model_manager/util/libc_util.py @@ -0,0 +1,75 @@ +import ctypes + + +class 
Struct_mallinfo2(ctypes.Structure): + """A ctypes Structure that matches the libc mallinfo2 struct. + + Docs: + - https://man7.org/linux/man-pages/man3/mallinfo.3.html + - https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html + + struct mallinfo2 { + size_t arena; /* Non-mmapped space allocated (bytes) */ + size_t ordblks; /* Number of free chunks */ + size_t smblks; /* Number of free fastbin blocks */ + size_t hblks; /* Number of mmapped regions */ + size_t hblkhd; /* Space allocated in mmapped regions (bytes) */ + size_t usmblks; /* See below */ + size_t fsmblks; /* Space in freed fastbin blocks (bytes) */ + size_t uordblks; /* Total allocated space (bytes) */ + size_t fordblks; /* Total free space (bytes) */ + size_t keepcost; /* Top-most, releasable space (bytes) */ + }; + """ + + _fields_ = [ + ("arena", ctypes.c_size_t), + ("ordblks", ctypes.c_size_t), + ("smblks", ctypes.c_size_t), + ("hblks", ctypes.c_size_t), + ("hblkhd", ctypes.c_size_t), + ("usmblks", ctypes.c_size_t), + ("fsmblks", ctypes.c_size_t), + ("uordblks", ctypes.c_size_t), + ("fordblks", ctypes.c_size_t), + ("keepcost", ctypes.c_size_t), + ] + + def __str__(self): + s = "" + s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n" + s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n" + s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n" + s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n" + s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n" + s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n" + s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n" + s += ( + f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)" + " (GB)\n" + ) + s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n" + s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n" + return s + + +class LibcUtil: + """A utility class for interacting with the C Standard Library (`libc`) via ctypes. + + Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where + this shared library is not available. + + TODO: Improve cross-OS compatibility of this class. + """ + + def __init__(self): + self._libc = ctypes.cdll.LoadLibrary("libc.so.6") + + def mallinfo2(self) -> Struct_mallinfo2: + """Calls `libc` `mallinfo2`. 
+ + Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html + """ + mallinfo2 = self._libc.mallinfo2 + mallinfo2.restype = Struct_mallinfo2 + return mallinfo2() diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py new file mode 100644 index 00000000000..6847a40878c --- /dev/null +++ b/invokeai/backend/model_manager/util/model_util.py @@ -0,0 +1,129 @@ +"""Utilities for parsing model files, used mostly by probe.py""" + +import json +import torch +from typing import Union +from pathlib import Path +from picklescan.scanner import scan_file_path + +def _fast_safetensors_reader(path: str): + checkpoint = {} + device = torch.device("meta") + with open(path, "rb") as f: + definition_len = int.from_bytes(f.read(8), "little") + definition_json = f.read(definition_len) + definition = json.loads(definition_json) + + if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in { + "pt", + "torch", + "pytorch", + }: + raise Exception("Supported only pytorch safetensors files") + definition.pop("__metadata__", None) + + for key, info in definition.items(): + dtype = { + "I8": torch.int8, + "I16": torch.int16, + "I32": torch.int32, + "I64": torch.int64, + "F16": torch.float16, + "F32": torch.float32, + "F64": torch.float64, + }[info["dtype"]] + + checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device) + + return checkpoint + +def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): + if str(path).endswith(".safetensors"): + try: + checkpoint = _fast_safetensors_reader(path) + except Exception: + # TODO: create issue for support "meta"? + checkpoint = safetensors.torch.load_file(path, device="cpu") + else: + if scan: + scan_result = scan_file_path(path) + if scan_result.infected_files != 0: + raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.') + checkpoint = torch.load(path, map_location=torch.device("meta")) + return checkpoint + +def lora_token_vector_length(checkpoint: dict) -> int: + """ + Given a checkpoint in memory, return the lora token vector length + + :param checkpoint: The checkpoint + """ + + def _get_shape_1(key: str, tensor, checkpoint) -> int: + lora_token_vector_length = None + + if "." 
not in key: + return lora_token_vector_length # wrong key format + model_key, lora_key = key.split(".", 1) + + # check lora/locon + if lora_key == "lora_down.weight": + lora_token_vector_length = tensor.shape[1] + + # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes) + elif lora_key in ["hada_w1_b", "hada_w2_b"]: + lora_token_vector_length = tensor.shape[1] + + # check lokr (don't worry about lokr_t2 as it used only in 4d shapes) + elif "lokr_" in lora_key: + if model_key + ".lokr_w1" in checkpoint: + _lokr_w1 = checkpoint[model_key + ".lokr_w1"] + elif model_key + "lokr_w1_b" in checkpoint: + _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"] + else: + return lora_token_vector_length # unknown format + + if model_key + ".lokr_w2" in checkpoint: + _lokr_w2 = checkpoint[model_key + ".lokr_w2"] + elif model_key + "lokr_w2_b" in checkpoint: + _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"] + else: + return lora_token_vector_length # unknown format + + lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1] + + elif lora_key == "diff": + lora_token_vector_length = tensor.shape[1] + + # ia3 can be detected only by shape[0] in text encoder + elif lora_key == "weight" and "lora_unet_" not in model_key: + lora_token_vector_length = tensor.shape[0] + + return lora_token_vector_length + + lora_token_vector_length = None + lora_te1_length = None + lora_te2_length = None + for key, tensor in checkpoint.items(): + if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key): + lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) + elif key.startswith("lora_unet_") and ( + "time_emb_proj.lora_down" in key + ): # recognizes format at https://civitai.com/models/224641 + lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) + elif key.startswith("lora_te") and "_self_attn_" in key: + tmp_length = _get_shape_1(key, tensor, checkpoint) + if key.startswith("lora_te_"): + lora_token_vector_length = tmp_length + elif key.startswith("lora_te1_"): + lora_te1_length = tmp_length + elif key.startswith("lora_te2_"): + lora_te2_length = tmp_length + + if lora_te1_length is not None and lora_te2_length is not None: + lora_token_vector_length = lora_te1_length + lora_te2_length + + if lora_token_vector_length is not None: + break + + return lora_token_vector_length diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/model_patcher.py similarity index 100% rename from invokeai/backend/embeddings/model_patcher.py rename to invokeai/backend/model_patcher.py diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py index f79fa015692..9b2096abdf0 100644 --- a/invokeai/backend/onnx/onnx_runtime.py +++ b/invokeai/backend/onnx/onnx_runtime.py @@ -8,6 +8,7 @@ import onnx from onnx import numpy_helper from onnxruntime import InferenceSession, SessionOptions, get_available_providers +from ..raw_model import RawModel ONNX_WEIGHTS_NAME = "model.onnx" diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py new file mode 100644 index 00000000000..2e224d538b3 --- /dev/null +++ b/invokeai/backend/raw_model.py @@ -0,0 +1,14 @@ +"""Base class for 'Raw' models. + +The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw, +and is used for type checking of calls to the model patcher. 
Its main purpose +is to avoid a circular import issues when lora.py tries to import BaseModelType +from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw +from lora.py. + +The term 'raw' was introduced to describe a wrapper around a torch.nn.Module +that adds additional methods and attributes. +""" + +class RawModel: + """Base class for 'Raw' model wrappers.""" diff --git a/invokeai/backend/embeddings/textual_inversion.py b/invokeai/backend/textual_inversion.py similarity index 97% rename from invokeai/backend/embeddings/textual_inversion.py rename to invokeai/backend/textual_inversion.py index 389edff039d..9a4fa0b5402 100644 --- a/invokeai/backend/embeddings/textual_inversion.py +++ b/invokeai/backend/textual_inversion.py @@ -8,11 +8,9 @@ from safetensors.torch import load_file from transformers import CLIPTokenizer from typing_extensions import Self +from .raw_model import RawModel -from .embedding_base import EmbeddingModelRaw - - -class TextualInversionModelRaw(EmbeddingModelRaw): +class TextualInversionModelRaw(RawModel): embedding: torch.Tensor # [n, 768]|[n, 1280] embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index 685603cedc6..a3def182c8c 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -5,10 +5,9 @@ import pytest import torch -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import ModelInstall -from invokeai.backend.model_management.model_manager import LoadedModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType +from invokeai.app.services.model_manager import ModelManagerServiceBase +from invokeai.app.services.model_records import UnknownModelException +from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType @pytest.fixture(scope="session") @@ -16,31 +15,20 @@ def torch_device(): return "cuda" if torch.cuda.is_available() else "cpu" -@pytest.fixture(scope="module") -def model_installer(): - """A global ModelInstall pytest fixture to be used by many tests.""" - # HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions - # between tests that need to alter the config. For example, some tests change the 'root' directory in the config, - # which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround, - # we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using - # a singleton. - return ModelInstall(InvokeAIAppConfig.get_config(log_level="info")) - - def install_and_load_model( - model_installer: ModelInstall, + model_manager: ModelManagerServiceBase, model_path_id_or_url: Union[str, Path], model_name: str, base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, -) -> LoadedModelInfo: - """Install a model if it is not already installed, then get the LoadedModelInfo for that model. +) -> LoadedModel: + """Install a model if it is not already installed, then get the LoadedModel for that model. This is intended as a utility function for tests. Args: - model_installer (ModelInstall): The model installer. 
+ mm2_model_manager (ModelManagerServiceBase): The model manager model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it is not already installed. model_name (str): The model name, forwarded to ModelManager.get_model(...). @@ -51,16 +39,23 @@ def install_and_load_model( Returns: LoadedModelInfo """ - # If the requested model is already installed, return its LoadedModelInfo. - with contextlib.suppress(ModelNotFoundException): - return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) + # If the requested model is already installed, return its LoadedModel + with contextlib.suppress(UnknownModelException): + # TODO: Replace with wrapper call + loaded_model: LoadedModel = model_manager.load.load_model_by_attr( + model_name=model_name, base_model=base_model, model_type=model_type + ) + return loaded_model # Install the requested model. - model_installer.heuristic_import(model_path_id_or_url) + job = model_manager.install.heuristic_import(model_path_id_or_url) + model_manager.install.wait_for_job(job, timeout=10) + assert job.complete try: - return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) - except ModelNotFoundException as e: + loaded_model = model_manager.load.load_model_by_config(job.config_out) + return loaded_model + except UnknownModelException as e: raise Exception( "Failed to get model info after installing it. There could be a mismatch between the requested model and" f" the installation id ('{model_path_id_or_url}'). Error: {e}" diff --git a/invokeai/configs/INITIAL_MODELS.yaml.OLD b/invokeai/configs/INITIAL_MODELS.yaml.OLD deleted file mode 100644 index c230665e3a6..00000000000 --- a/invokeai/configs/INITIAL_MODELS.yaml.OLD +++ /dev/null @@ -1,153 +0,0 @@ -# This file predefines a few models that the user may want to install. 
-sd-1/main/stable-diffusion-v1-5: - description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - repo_id: runwayml/stable-diffusion-v1-5 - recommended: True - default: True -sd-1/main/stable-diffusion-v1-5-inpainting: - description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - repo_id: runwayml/stable-diffusion-inpainting - recommended: True -sd-2/main/stable-diffusion-2-1: - description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-1 - recommended: False -sd-2/main/stable-diffusion-2-inpainting: - description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-inpainting - recommended: False -sdxl/main/stable-diffusion-xl-base-1-0: - description: Stable Diffusion XL base model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-base-1.0 - recommended: True -sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: - description: Stable Diffusion XL refiner model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 - recommended: False -sdxl/vae/sdxl-1-0-vae-fix: - description: Fine tuned version of the SDXL-1.0 VAE - repo_id: madebyollin/sdxl-vae-fp16-fix - recommended: True -sd-1/main/Analog-Diffusion: - description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - repo_id: wavymulder/Analog-Diffusion - recommended: False -sd-1/main/Deliberate_v5: - description: Versatile model that produces detailed images up to 768px (4.27 GB) - path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors - recommended: False -sd-1/main/Dungeons-and-Diffusion: - description: Dungeons & Dragons characters (2.13 GB) - repo_id: 0xJustin/Dungeons-and-Diffusion - recommended: False -sd-1/main/dreamlike-photoreal-2: - description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - repo_id: dreamlike-art/dreamlike-photoreal-2.0 - recommended: False -sd-1/main/Inkpunk-Diffusion: - description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - repo_id: Envvi/Inkpunk-Diffusion - recommended: False -sd-1/main/openjourney: - description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - repo_id: prompthero/openjourney - recommended: False -sd-1/main/seek.art_MEGA: - repo_id: coreco/seek.art_MEGA - description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) - recommended: False -sd-1/main/trinart_stable_diffusion_v2: - description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - repo_id: naclbit/trinart_stable_diffusion_v2 - recommended: False -sd-1/controlnet/qrcode_monster: - repo_id: monster-labs/control_v1p_sd15_qrcode_monster - subfolder: v2 -sd-1/controlnet/canny: - repo_id: lllyasviel/control_v11p_sd15_canny - recommended: True -sd-1/controlnet/inpaint: - repo_id: lllyasviel/control_v11p_sd15_inpaint -sd-1/controlnet/mlsd: - repo_id: lllyasviel/control_v11p_sd15_mlsd -sd-1/controlnet/depth: - repo_id: lllyasviel/control_v11f1p_sd15_depth - recommended: True -sd-1/controlnet/normal_bae: - repo_id: lllyasviel/control_v11p_sd15_normalbae -sd-1/controlnet/seg: - repo_id: lllyasviel/control_v11p_sd15_seg -sd-1/controlnet/lineart: - repo_id: lllyasviel/control_v11p_sd15_lineart - recommended: True -sd-1/controlnet/lineart_anime: - repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime -sd-1/controlnet/openpose: - 
repo_id: lllyasviel/control_v11p_sd15_openpose - recommended: True -sd-1/controlnet/scribble: - repo_id: lllyasviel/control_v11p_sd15_scribble - recommended: False -sd-1/controlnet/softedge: - repo_id: lllyasviel/control_v11p_sd15_softedge -sd-1/controlnet/shuffle: - repo_id: lllyasviel/control_v11e_sd15_shuffle -sd-1/controlnet/tile: - repo_id: lllyasviel/control_v11f1e_sd15_tile -sd-1/controlnet/ip2p: - repo_id: lllyasviel/control_v11e_sd15_ip2p -sd-1/t2i_adapter/canny-sd15: - repo_id: TencentARC/t2iadapter_canny_sd15v2 -sd-1/t2i_adapter/sketch-sd15: - repo_id: TencentARC/t2iadapter_sketch_sd15v2 -sd-1/t2i_adapter/depth-sd15: - repo_id: TencentARC/t2iadapter_depth_sd15v2 -sd-1/t2i_adapter/zoedepth-sd15: - repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 -sdxl/t2i_adapter/canny-sdxl: - repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 -sdxl/t2i_adapter/zoedepth-sdxl: - repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 -sdxl/t2i_adapter/lineart-sdxl: - repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 -sdxl/t2i_adapter/sketch-sdxl: - repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 -sd-1/embedding/EasyNegative: - path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors - recommended: True -sd-1/embedding/ahx-beta-453407d: - repo_id: sd-concepts-library/ahx-beta-453407d -sd-1/lora/Ink scenery: - path: https://civitai.com/api/download/models/83390 -sd-1/ip_adapter/ip_adapter_sd15: - repo_id: InvokeAI/ip_adapter_sd15 - recommended: True - requires: - - InvokeAI/ip_adapter_sd_image_encoder - description: IP-Adapter for SD 1.5 models -sd-1/ip_adapter/ip_adapter_plus_sd15: - repo_id: InvokeAI/ip_adapter_plus_sd15 - recommended: False - requires: - - InvokeAI/ip_adapter_sd_image_encoder - description: Refined IP-Adapter for SD 1.5 models -sd-1/ip_adapter/ip_adapter_plus_face_sd15: - repo_id: InvokeAI/ip_adapter_plus_face_sd15 - recommended: False - requires: - - InvokeAI/ip_adapter_sd_image_encoder - description: Refined IP-Adapter for SD 1.5 models, adapted for faces -sdxl/ip_adapter/ip_adapter_sdxl: - repo_id: InvokeAI/ip_adapter_sdxl - recommended: False - requires: - - InvokeAI/ip_adapter_sdxl_image_encoder - description: IP-Adapter for SDXL models -any/clip_vision/ip_adapter_sd_image_encoder: - repo_id: InvokeAI/ip_adapter_sd_image_encoder - recommended: False - description: Required model for using IP-Adapters with SD-1/2 models -any/clip_vision/ip_adapter_sdxl_image_encoder: - repo_id: InvokeAI/ip_adapter_sdxl_image_encoder - recommended: False - description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/configs/models.yaml.example b/invokeai/configs/models.yaml.example deleted file mode 100644 index 98f8f77e62c..00000000000 --- a/invokeai/configs/models.yaml.example +++ /dev/null @@ -1,47 +0,0 @@ -# This file describes the alternative machine learning models -# available to InvokeAI script. -# -# To add a new model, follow the examples below. Each -# model requires a model config file, a weights file, -# and the width and height of the images it -# was trained on. 
-diffusers-1.4: - description: 🤗🧨 Stable Diffusion v1.4 - format: diffusers - repo_id: CompVis/stable-diffusion-v1-4 -diffusers-1.5: - description: 🤗🧨 Stable Diffusion v1.5 - format: diffusers - repo_id: runwayml/stable-diffusion-v1-5 - default: true -diffusers-1.5+mse: - description: 🤗🧨 Stable Diffusion v1.5 + MSE-finetuned VAE - format: diffusers - repo_id: runwayml/stable-diffusion-v1-5 - vae: - repo_id: stabilityai/sd-vae-ft-mse -diffusers-inpainting-1.5: - description: 🤗🧨 inpainting for Stable Diffusion v1.5 - format: diffusers - repo_id: runwayml/stable-diffusion-inpainting -stable-diffusion-1.5: - description: The newest Stable Diffusion version 1.5 weight file (4.27 GB) - weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt - config: configs/stable-diffusion/v1-inference.yaml - width: 512 - height: 512 - vae: ./models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt -stable-diffusion-1.4: - description: Stable Diffusion inference model version 1.4 - config: configs/stable-diffusion/v1-inference.yaml - weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt - vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt - width: 512 - height: 512 -inpainting-1.5: - weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt - config: configs/stable-diffusion/v1-inpainting-inference.yaml - vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt - description: RunwayML SD 1.5 model optimized for inpainting - width: 512 - height: 512 diff --git a/invokeai/frontend/install/model_install.py.OLD b/invokeai/frontend/install/model_install.py.OLD deleted file mode 100644 index e23538ffd66..00000000000 --- a/invokeai/frontend/install/model_install.py.OLD +++ /dev/null @@ -1,845 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein) -# Before running stable-diffusion on an internet-isolated machine, -# run this script from one with internet connectivity. The -# two machines must share a common .cache directory. - -""" -This is the npyscreen frontend to the model installation application. -The work is actually done in backend code in model_install_backend.py. 
-""" - -import argparse -import curses -import logging -import sys -import textwrap -import traceback -from argparse import Namespace -from multiprocessing import Process -from multiprocessing.connection import Connection, Pipe -from pathlib import Path -from shutil import get_terminal_size -from typing import Optional - -import npyscreen -import torch -from npyscreen import widget - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType -from invokeai.backend.model_management import ModelManager, ModelType -from invokeai.backend.util import choose_precision, choose_torch_device -from invokeai.backend.util.logging import InvokeAILogger -from invokeai.frontend.install.widgets import ( - MIN_COLS, - MIN_LINES, - BufferBox, - CenteredTitleText, - CyclingForm, - MultiSelectColumns, - SingleSelectColumns, - TextBox, - WindowTooSmallException, - select_stable_diffusion_config_file, - set_min_terminal_size, -) - -config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger() - -# build a table mapping all non-printable characters to None -# for stripping control characters -# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python -NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()} - -# maximum number of installed models we can display before overflowing vertically -MAX_OTHER_MODELS = 72 - - -def make_printable(s: str) -> str: - """Replace non-printable characters in a string""" - return s.translate(NOPRINT_TRANS_TABLE) - - -class addModelsForm(CyclingForm, npyscreen.FormMultiPage): - # for responsive resizing set to False, but this seems to cause a crash! - FIX_MINIMUM_SIZE_WHEN_CREATED = True - - # for persistence - current_tab = 0 - - def __init__(self, parentApp, name, multipage=False, *args, **keywords): - self.multipage = multipage - self.subprocess = None - super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? - - def create(self): - self.keypress_timeout = 10 - self.counter = 0 - self.subprocess_connection = None - - if not config.model_conf_path.exists(): - with open(config.model_conf_path, "w") as file: - print("# InvokeAI model configuration file", file=file) - self.installer = ModelInstall(config) - self.all_models = self.installer.all_models() - self.starter_models = self.installer.starter_models() - self.model_labels = self._get_model_labels() - window_width, window_height = get_terminal_size() - - self.nextrely -= 1 - self.add_widget_intelligent( - npyscreen.FixedText, - value="Use ctrl-N and ctrl-P to move to the ext and
revious fields. Cursor keys navigate, and selects.", - editable=False, - color="CAUTION", - ) - self.nextrely += 1 - self.tabs = self.add_widget_intelligent( - SingleSelectColumns, - values=[ - "STARTERS", - "MAINS", - "CONTROLNETS", - "T2I-ADAPTERS", - "IP-ADAPTERS", - "LORAS", - "TI EMBEDDINGS", - ], - value=[self.current_tab], - columns=7, - max_height=2, - relx=8, - scroll_exit=True, - ) - self.tabs.on_changed = self._toggle_tables - - top_of_table = self.nextrely - self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely - - self.nextrely = top_of_table - self.pipeline_models = self.add_pipeline_widgets( - model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models - ) - # self.pipeline_models['autoload_pending'] = True - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.controlnet_models = self.add_model_widgets( - model_type=ModelType.ControlNet, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.t2i_models = self.add_model_widgets( - model_type=ModelType.T2IAdapter, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - self.nextrely = top_of_table - self.ipadapter_models = self.add_model_widgets( - model_type=ModelType.IPAdapter, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.lora_models = self.add_model_widgets( - model_type=ModelType.Lora, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.ti_models = self.add_model_widgets( - model_type=ModelType.TextualInversion, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = bottom_of_table + 1 - - self.monitor = self.add_widget_intelligent( - BufferBox, - name="Log Messages", - editable=False, - max_height=6, - ) - - self.nextrely += 1 - done_label = "APPLY CHANGES" - back_label = "BACK" - cancel_label = "CANCEL" - current_position = self.nextrely - if self.multipage: - self.back_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=back_label, - when_pressed_function=self.on_back, - ) - else: - self.nextrely = current_position - self.cancel_button = self.add_widget_intelligent( - npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel - ) - self.nextrely = current_position - self.ok_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=done_label, - relx=(window_width - len(done_label)) // 2, - when_pressed_function=self.on_execute, - ) - - label = "APPLY CHANGES & EXIT" - self.nextrely = current_position - self.done = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=label, - relx=window_width - len(label) - 15, - when_pressed_function=self.on_done, - ) - - # This restores the selected page on return from an installation - for _i in range(1, self.current_tab + 1): - self.tabs.h_cursor_line_down(1) - self._toggle_tables([self.current_tab]) - - ############# diffusers tab ########## - def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: - """Add widgets responsible for selecting diffusers models""" - widgets = {} - models = self.all_models - starters = self.starter_models - starter_model_labels = self.model_labels - - self.installed_models = sorted([x for x in starters if models[x].installed]) - - widgets.update( - 
label1=self.add_widget_intelligent( - CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace.", - editable=False, - labelColor="CAUTION", - ) - ) - - self.nextrely -= 1 - # if user has already installed some initial models, then don't patronize them - # by showing more recommendations - show_recommended = len(self.installed_models) == 0 - keys = [x for x in models.keys() if x in starters] - widgets.update( - models_selected=self.add_widget_intelligent( - MultiSelectColumns, - columns=1, - name="Install Starter Models", - values=[starter_model_labels[x] for x in keys], - value=[ - keys.index(x) - for x in keys - if (show_recommended and models[x].recommended) or (x in self.installed_models) - ], - max_height=len(starters) + 1, - relx=4, - scroll_exit=True, - ), - models=keys, - ) - - self.nextrely += 1 - return widgets - - ############# Add a set of model install widgets ######## - def add_model_widgets( - self, - model_type: ModelType, - window_width: int = 120, - install_prompt: str = None, - exclude: set = None, - ) -> dict[str, npyscreen.widget]: - """Generic code to create model selection widgets""" - if exclude is None: - exclude = set() - widgets = {} - model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] - model_labels = [self.model_labels[x] for x in model_list] - - show_recommended = len(self.installed_models) == 0 - truncated = False - if len(model_list) > 0: - max_width = max([len(x) for x in model_labels]) - columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding - columns = min(len(model_list), columns) or 1 - prompt = ( - install_prompt - or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk." - ) - - widgets.update( - label1=self.add_widget_intelligent( - CenteredTitleText, - name=prompt, - editable=False, - labelColor="CAUTION", - ) - ) - - if len(model_labels) > MAX_OTHER_MODELS: - model_labels = model_labels[0:MAX_OTHER_MODELS] - truncated = True - - widgets.update( - models_selected=self.add_widget_intelligent( - MultiSelectColumns, - columns=columns, - name=f"Install {model_type} Models", - values=model_labels, - value=[ - model_list.index(x) - for x in model_list - if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed - ], - max_height=len(model_list) // columns + 1, - relx=4, - scroll_exit=True, - ), - models=model_list, - ) - - if truncated: - widgets.update( - warning_message=self.add_widget_intelligent( - npyscreen.FixedText, - value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.", - editable=False, - color="CAUTION", - ) - ) - - self.nextrely += 1 - widgets.update( - download_ids=self.add_widget_intelligent( - TextBox, - name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=4, - scroll_exit=True, - editable=True, - ) - ) - return widgets - - ### Tab for arbitrary diffusers widgets ### - def add_pipeline_widgets( - self, - model_type: ModelType = ModelType.Main, - window_width: int = 120, - **kwargs, - ) -> dict[str, npyscreen.widget]: - """Similar to add_model_widgets() but adds some additional widgets at the bottom - to support the autoload directory""" - widgets = self.add_model_widgets( - model_type=model_type, - window_width=window_width, - install_prompt=f"Installed {model_type.value.title()} models. 
Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import.", - **kwargs, - ) - - return widgets - - def resize(self): - super().resize() - if s := self.starter_pipelines.get("models_selected"): - keys = [x for x in self.all_models.keys() if x in self.starter_models] - s.values = [self.model_labels[x] for x in keys] - - def _toggle_tables(self, value=None): - selected_tab = value[0] - widgets = [ - self.starter_pipelines, - self.pipeline_models, - self.controlnet_models, - self.t2i_models, - self.ipadapter_models, - self.lora_models, - self.ti_models, - ] - - for group in widgets: - for _k, v in group.items(): - try: - v.hidden = True - v.editable = False - except Exception: - pass - for _k, v in widgets[selected_tab].items(): - try: - v.hidden = False - if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)): - v.editable = True - except Exception: - pass - self.__class__.current_tab = selected_tab # for persistence - self.display() - - def _get_model_labels(self) -> dict[str, str]: - window_width, window_height = get_terminal_size() - checkbox_width = 4 - spacing_width = 2 - - models = self.all_models - label_width = max([len(models[x].name) for x in models]) - description_width = window_width - label_width - checkbox_width - spacing_width - - result = {} - for x in models.keys(): - description = models[x].description - description = ( - description[0 : description_width - 3] + "..." - if description and len(description) > description_width - else description - if description - else "" - ) - result[x] = f"%-{label_width}s %s" % (models[x].name, description) - return result - - def _get_columns(self) -> int: - window_width, window_height = get_terminal_size() - cols = 4 if window_width > 240 else 3 if window_width > 160 else 2 if window_width > 80 else 1 - return min(cols, len(self.installed_models)) - - def confirm_deletions(self, selections: InstallSelections) -> bool: - remove_models = selections.remove_models - if len(remove_models) > 0: - mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) - return npyscreen.notify_ok_cancel( - f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}" - ) - else: - return True - - def on_execute(self): - self.marshall_arguments() - app = self.parentApp - if not self.confirm_deletions(app.install_selections): - return - - self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) - self.ok_button.hidden = True - self.display() - - # TO DO: Spawn a worker thread, not a subprocess - parent_conn, child_conn = Pipe() - p = Process( - target=process_and_execute, - kwargs={ - "opt": app.program_opts, - "selections": app.install_selections, - "conn_out": child_conn, - }, - ) - p.start() - child_conn.close() - self.subprocess_connection = parent_conn - self.subprocess = p - app.install_selections = InstallSelections() - - def on_back(self): - self.parentApp.switchFormPrevious() - self.editing = False - - def on_cancel(self): - self.parentApp.setNextForm(None) - self.parentApp.user_cancelled = True - self.editing = False - - def on_done(self): - self.marshall_arguments() - if not self.confirm_deletions(self.parentApp.install_selections): - return - self.parentApp.setNextForm(None) - self.parentApp.user_cancelled = False - self.editing = False - - ########## This routine monitors the child process that is performing model installation and removal ##### - def while_waiting(self): - """Called during idle periods. 
Main task is to update the Log Messages box with messages - from the child process that does the actual installation/removal""" - c = self.subprocess_connection - if not c: - return - - monitor_widget = self.monitor.entry_widget - while c.poll(): - try: - data = c.recv_bytes().decode("utf-8") - data.strip("\n") - - # processing child is requesting user input to select the - # right configuration file - if data.startswith("*need v2 config"): - _, model_path, *_ = data.split(":", 2) - self._return_v2_config(model_path) - - # processing child is done - elif data == "*done*": - self._close_subprocess_and_regenerate_form() - break - - # update the log message box - else: - data = make_printable(data) - data = data.replace("[A", "") - monitor_widget.buffer( - textwrap.wrap( - data, - width=monitor_widget.width, - subsequent_indent=" ", - ), - scroll_end=True, - ) - self.display() - except (EOFError, OSError): - self.subprocess_connection = None - - def _return_v2_config(self, model_path: str): - c = self.subprocess_connection - model_name = Path(model_path).name - message = select_stable_diffusion_config_file(model_name=model_name) - c.send_bytes(message.encode("utf-8")) - - def _close_subprocess_and_regenerate_form(self): - app = self.parentApp - self.subprocess_connection.close() - self.subprocess_connection = None - self.monitor.entry_widget.buffer(["** Action Complete **"]) - self.display() - - # rebuild the form, saving and restoring some of the fields that need to be preserved. - saved_messages = self.monitor.entry_widget.values - - app.main_form = app.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - multipage=self.multipage, - ) - app.switchForm("MAIN") - - app.main_form.monitor.entry_widget.values = saved_messages - app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) - # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir - # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan - - def marshall_arguments(self): - """ - Assemble arguments and store as attributes of the application: - .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml - True => Install - False => Remove - .scan_directory: Path to a directory of models to scan and import - .autoscan_on_startup: True if invokeai should scan and import at startup time - .import_model_paths: list of URLs, repo_ids and file paths to import - """ - selections = self.parentApp.install_selections - all_models = self.all_models - - # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove - ui_sections = [ - self.starter_pipelines, - self.pipeline_models, - self.controlnet_models, - self.t2i_models, - self.ipadapter_models, - self.lora_models, - self.ti_models, - ] - for section in ui_sections: - if "models_selected" not in section: - continue - selected = {section["models"][x] for x in section["models_selected"].value} - models_to_install = [x for x in selected if not self.all_models[x].installed] - models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] - selections.remove_models.extend(models_to_remove) - selections.install_models.extend( - all_models[x].path or all_models[x].repo_id - for x in models_to_install - if all_models[x].path or all_models[x].repo_id - ) - - # models located in the 'download_ids" section - for section in ui_sections: - if downloads := section.get("download_ids"): - selections.install_models.extend(downloads.value.split()) - - # NOT NEEDED - DONE IN 
BACKEND NOW - # # special case for the ipadapter_models. If any of the adapters are - # # chosen, then we add the corresponding encoder(s) to the install list. - # section = self.ipadapter_models - # if section.get("models_selected"): - # selected_adapters = [ - # self.all_models[section["models"][x]].name for x in section.get("models_selected").value - # ] - # encoders = [] - # if any(["sdxl" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sdxl_image_encoder") - # if any(["sd15" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sd_image_encoder") - # for encoder in encoders: - # key = f"any/clip_vision/{encoder}" - # repo_id = f"InvokeAI/{encoder}" - # if key not in self.all_models: - # selections.install_models.append(repo_id) - - -class AddModelApplication(npyscreen.NPSAppManaged): - def __init__(self, opt): - super().__init__() - self.program_opts = opt - self.user_cancelled = False - # self.autoload_pending = True - self.install_selections = InstallSelections() - - def onStart(self): - npyscreen.setTheme(npyscreen.Themes.DefaultTheme) - self.main_form = self.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - cycle_widgets=False, - ) - - -class StderrToMessage: - def __init__(self, connection: Connection): - self.connection = connection - - def write(self, data: str): - self.connection.send_bytes(data.encode("utf-8")) - - def flush(self): - pass - - -# -------------------------------------------------------- -def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: - if tui_conn: - logger.debug("Waiting for user response...") - return _ask_user_for_pt_tui(model_path, tui_conn) - else: - return _ask_user_for_pt_cmdline(model_path) - - -def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]: - choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] - print( - f""" -Please select the scheduler prediction type of the checkpoint named {model_path.name}: -[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images -[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models -[3] Accept the best guess; you can fix it in the Web UI later -""" - ) - choice = None - ok = False - while not ok: - try: - choice = input("select [3]> ").strip() - if not choice: - return None - choice = choices[int(choice) - 1] - ok = True - except (ValueError, IndexError): - print(f"{choice} is not a valid choice") - except EOFError: - return - return choice - - -def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: - tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) - # note that we don't do any status checking here - response = tui_conn.recv_bytes().decode("utf-8") - if response is None: - return None - elif response == "epsilon": - return SchedulerPredictionType.epsilon - elif response == "v": - return SchedulerPredictionType.VPrediction - elif response == "guess": - return None - else: - return None - - -# -------------------------------------------------------- -def process_and_execute( - opt: Namespace, - selections: InstallSelections, - conn_out: Connection = None, -): - # need to reinitialize config in subprocess - config = InvokeAIAppConfig.get_config() - args = ["--root", opt.root] if opt.root else [] - config.parse_args(args) - - # set up so that stderr is sent to conn_out - if conn_out: - translator = StderrToMessage(conn_out) - 
sys.stderr = translator - sys.stdout = translator - logger = InvokeAILogger.get_logger() - logger.handlers.clear() - logger.addHandler(logging.StreamHandler(translator)) - - installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) - installer.install(selections) - - if conn_out: - conn_out.send_bytes("*done*".encode("utf-8")) - conn_out.close() - - -# -------------------------------------------------------- -def select_and_download_models(opt: Namespace): - precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) - config.precision = precision - installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) - if opt.list_models: - installer.list_models(opt.list_models) - elif opt.add or opt.delete: - selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) - installer.install(selections) - elif opt.default_only: - selections = InstallSelections(install_models=installer.default_model()) - installer.install(selections) - elif opt.yes_to_all: - selections = InstallSelections(install_models=installer.recommended_models()) - installer.install(selections) - - # this is where the TUI is called - else: - # needed to support the probe() method running under a subprocess - torch.multiprocessing.set_start_method("spawn") - - if not set_min_terminal_size(MIN_COLS, MIN_LINES): - raise WindowTooSmallException( - "Could not increase terminal size. Try running again with a larger window or smaller font size." - ) - - installApp = AddModelApplication(opt) - try: - installApp.run() - except KeyboardInterrupt as e: - if hasattr(installApp, "main_form"): - if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): - logger.info("Terminating subprocesses") - installApp.main_form.subprocess.terminate() - installApp.main_form.subprocess = None - raise e - process_and_execute(opt, installApp.install_selections) - - -# ------------------------------------- -def main(): - parser = argparse.ArgumentParser(description="InvokeAI model downloader") - parser.add_argument( - "--add", - nargs="*", - help="List of URLs, local paths or repo_ids of models to install", - ) - parser.add_argument( - "--delete", - nargs="*", - help="List of names of models to idelete", - ) - parser.add_argument( - "--full-precision", - dest="full_precision", - action=argparse.BooleanOptionalAction, - type=bool, - default=False, - help="use 32-bit weights instead of faster 16-bit weights", - ) - parser.add_argument( - "--yes", - "-y", - dest="yes_to_all", - action="store_true", - help='answer "yes" to all prompts', - ) - parser.add_argument( - "--default_only", - action="store_true", - help="Only install the default model", - ) - parser.add_argument( - "--list-models", - choices=[x.value for x in ModelType], - help="list installed models", - ) - parser.add_argument( - "--config_file", - "-c", - dest="config_file", - type=str, - default=None, - help="path to configuration file to create", - ) - parser.add_argument( - "--root_dir", - dest="root", - type=str, - default=None, - help="path to root of install directory", - ) - opt = parser.parse_args() - - invoke_args = [] - if opt.root: - invoke_args.extend(["--root", opt.root]) - if opt.full_precision: - invoke_args.extend(["--precision", "float32"]) - config.parse_args(invoke_args) - logger = InvokeAILogger().get_logger(config=config) - - if not config.model_conf_path.exists(): - logger.info("Your InvokeAI root directory 
is not set up. Calling invokeai-configure.") - from invokeai.frontend.install.invokeai_configure import invokeai_configure - - invokeai_configure() - sys.exit(0) - - try: - select_and_download_models(opt) - except AssertionError as e: - logger.error(e) - sys.exit(-1) - except KeyboardInterrupt: - curses.nocbreak() - curses.echo() - curses.endwin() - logger.info("Goodbye! Come back soon.") - except WindowTooSmallException as e: - logger.error(str(e)) - except widget.NotEnoughSpaceForWidget as e: - if str(e).startswith("Height of 1 allocated"): - logger.error("Insufficient vertical space for the interface. Please make your window taller and try again") - input("Press any key to continue...") - except Exception as e: - if str(e).startswith("addwstr"): - logger.error( - "Insufficient horizontal space for the interface. Please make your window wider and try again." - ) - else: - print(f"An exception has occurred: {str(e)} Details:") - print(traceback.format_exc(), file=sys.stderr) - input("Press any key to continue...") - - -# ------------------------------------- -if __name__ == "__main__": - main() diff --git a/invokeai/frontend/merge/merge_diffusers.py.OLD b/invokeai/frontend/merge/merge_diffusers.py.OLD deleted file mode 100644 index b365198f879..00000000000 --- a/invokeai/frontend/merge/merge_diffusers.py.OLD +++ /dev/null @@ -1,438 +0,0 @@ -""" -invokeai.frontend.merge exports a single function called merge_diffusion_models(). - -It merges 2-3 models together and create a new InvokeAI-registered diffusion model. - -Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team -""" -import argparse -import curses -import re -import sys -from argparse import Namespace -from pathlib import Path -from typing import List, Optional, Tuple - -import npyscreen -from npyscreen import widget - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.install.install_helper import initialize_installer -from invokeai.backend.model_manager import ( - BaseModelType, - ModelFormat, - ModelType, - ModelVariantType, -) -from invokeai.backend.model_manager.merge import ModelMerger -from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox - -config = InvokeAIAppConfig.get_config() - -BASE_TYPES = [ - (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), - (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), - (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), -] - - -def _parse_args() -> Namespace: - parser = argparse.ArgumentParser(description="InvokeAI model merging") - parser.add_argument( - "--root_dir", - type=Path, - default=config.root, - help="Path to the invokeai runtime directory", - ) - parser.add_argument( - "--front_end", - "--gui", - dest="front_end", - action="store_true", - default=False, - help="Activate the text-based graphical front end for collecting parameters. 
Aside from --root_dir, other parameters will be ignored.", - ) - parser.add_argument( - "--models", - dest="model_names", - type=str, - nargs="+", - help="Two to three model names to be merged", - ) - parser.add_argument( - "--base_model", - type=str, - choices=[x[0].value for x in BASE_TYPES], - help="The base model shared by the models to be merged", - ) - parser.add_argument( - "--merged_model_name", - "--destination", - dest="merged_model_name", - type=str, - help="Name of the output model. If not specified, will be the concatenation of the input model names.", - ) - parser.add_argument( - "--alpha", - type=float, - default=0.5, - help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models", - ) - parser.add_argument( - "--interpolation", - dest="interp", - type=str, - choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"], - default="weighted_sum", - help='Interpolation method to use. If three models are present, only "add_difference" will work.', - ) - parser.add_argument( - "--force", - action="store_true", - help="Try to merge models even if they are incompatible with each other", - ) - parser.add_argument( - "--clobber", - "--overwrite", - dest="clobber", - action="store_true", - help="Overwrite the merged model if --merged_model_name already exists", - ) - return parser.parse_args() - - -# ------------------------- GUI HERE ------------------------- -class mergeModelsForm(npyscreen.FormMultiPageAction): - interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"] - - def __init__(self, parentApp, name): - self.parentApp = parentApp - self.ALLOW_RESIZE = True - self.FIX_MINIMUM_SIZE_WHEN_CREATED = False - super().__init__(parentApp, name) - - @property - def model_record_store(self) -> ModelRecordServiceBase: - installer: ModelInstallServiceBase = self.parentApp.installer - return installer.record_store - - def afterEditing(self) -> None: - self.parentApp.setNextForm(None) - - def create(self) -> None: - window_height, window_width = curses.initscr().getmaxyx() - self.current_base = 0 - self.models = self.get_models(BASE_TYPES[self.current_base][0]) - self.model_names = [x[1] for x in self.models] - max_width = max([len(x) for x in self.model_names]) - max_width += 6 - horizontal_layout = max_width * 3 < window_width - - self.add_widget_intelligent( - npyscreen.FixedText, - color="CONTROL", - value="Select two models to merge and optionally a third.", - editable=False, - ) - self.add_widget_intelligent( - npyscreen.FixedText, - color="CONTROL", - value="Use up and down arrows to move, to select an item, and to move from one field to the next.", - editable=False, - ) - self.nextrely += 1 - self.base_select = self.add_widget_intelligent( - SingleSelectColumns, - values=[x[1] for x in BASE_TYPES], - value=[self.current_base], - columns=4, - max_height=2, - relx=8, - scroll_exit=True, - ) - self.base_select.on_changed = self._populate_models - self.add_widget_intelligent( - npyscreen.FixedText, - value="MODEL 1", - color="GOOD", - editable=False, - rely=6 if horizontal_layout else None, - ) - self.model1 = self.add_widget_intelligent( - npyscreen.SelectOne, - values=self.model_names, - value=0, - max_height=len(self.model_names), - max_width=max_width, - scroll_exit=True, - rely=7, - ) - self.add_widget_intelligent( - npyscreen.FixedText, - value="MODEL 2", - color="GOOD", - editable=False, - relx=max_width + 3 if horizontal_layout else None, - rely=6 if 
horizontal_layout else None, - ) - self.model2 = self.add_widget_intelligent( - npyscreen.SelectOne, - name="(2)", - values=self.model_names, - value=1, - max_height=len(self.model_names), - max_width=max_width, - relx=max_width + 3 if horizontal_layout else None, - rely=7 if horizontal_layout else None, - scroll_exit=True, - ) - self.add_widget_intelligent( - npyscreen.FixedText, - value="MODEL 3", - color="GOOD", - editable=False, - relx=max_width * 2 + 3 if horizontal_layout else None, - rely=6 if horizontal_layout else None, - ) - models_plus_none = self.model_names.copy() - models_plus_none.insert(0, "None") - self.model3 = self.add_widget_intelligent( - npyscreen.SelectOne, - name="(3)", - values=models_plus_none, - value=0, - max_height=len(self.model_names) + 1, - max_width=max_width, - scroll_exit=True, - relx=max_width * 2 + 3 if horizontal_layout else None, - rely=7 if horizontal_layout else None, - ) - for m in [self.model1, self.model2, self.model3]: - m.when_value_edited = self.models_changed - self.merged_model_name = self.add_widget_intelligent( - TextBox, - name="Name for merged model:", - labelColor="CONTROL", - max_height=3, - value="", - scroll_exit=True, - ) - self.force = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Force merge of models created by different diffusers library versions", - labelColor="CONTROL", - value=True, - scroll_exit=True, - ) - self.nextrely += 1 - self.merge_method = self.add_widget_intelligent( - npyscreen.TitleSelectOne, - name="Merge Method:", - values=self.interpolations, - value=0, - labelColor="CONTROL", - max_height=len(self.interpolations) + 1, - scroll_exit=True, - ) - self.alpha = self.add_widget_intelligent( - FloatTitleSlider, - name="Weight (alpha) to assign to second and third models:", - out_of=1.0, - step=0.01, - lowest=0, - value=0.5, - labelColor="CONTROL", - scroll_exit=True, - ) - self.model1.editing = True - - def models_changed(self) -> None: - models = self.model1.values - selected_model1 = self.model1.value[0] - selected_model2 = self.model2.value[0] - selected_model3 = self.model3.value[0] - merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}" - self.merged_model_name.value = merged_model_name - - if selected_model3 > 0: - self.merge_method.values = ["add_difference ( A+(B-C) )"] - self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one. 
- else: - self.merge_method.values = self.interpolations - self.merge_method.value = 0 - - def on_ok(self) -> None: - if self.validate_field_values() and self.check_for_overwrite(): - self.parentApp.setNextForm(None) - self.editing = False - self.parentApp.merge_arguments = self.marshall_arguments() - npyscreen.notify("Starting the merge...") - else: - self.editing = True - - def on_cancel(self) -> None: - sys.exit(0) - - def marshall_arguments(self) -> dict: - model_keys = [x[0] for x in self.models] - models = [ - model_keys[self.model1.value[0]], - model_keys[self.model2.value[0]], - ] - if self.model3.value[0] > 0: - models.append(model_keys[self.model3.value[0] - 1]) - interp = "add_difference" - else: - interp = self.interpolations[self.merge_method.value[0]] - - args = { - "model_keys": models, - "alpha": self.alpha.value, - "interp": interp, - "force": self.force.value, - "merged_model_name": self.merged_model_name.value, - } - return args - - def check_for_overwrite(self) -> bool: - model_out = self.merged_model_name.value - if model_out not in self.model_names: - return True - else: - result: bool = npyscreen.notify_yes_no( - f"The chosen merged model destination, {model_out}, is already in use. Overwrite?" - ) - return result - - def validate_field_values(self) -> bool: - bad_fields = [] - model_names = self.model_names - selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]} - if self.model3.value[0] > 0: - selected_models.add(model_names[self.model3.value[0] - 1]) - if len(selected_models) < 2: - bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}") - if len(bad_fields) > 0: - message = "The following problems were detected and must be corrected:" - for problem in bad_fields: - message += f"\n* {problem}" - npyscreen.notify_confirm(message) - return False - else: - return True - - def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name - models = [ - (x.key, x.name) - for x in self.model_record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model) - if x.format == ModelFormat("diffusers") - and hasattr(x, "variant") - and x.variant == ModelVariantType("normal") - ] - return sorted(models, key=lambda x: x[1]) - - def _populate_models(self, value: List[int]) -> None: - base_model = BASE_TYPES[value[0]][0] - self.models = self.get_models(base_model) - self.model_names = [x[1] for x in self.models] - - models_plus_none = self.model_names.copy() - models_plus_none.insert(0, "None") - self.model1.values = self.model_names - self.model2.values = self.model_names - self.model3.values = models_plus_none - - self.display() - - -# npyscreen is untyped and causes mypy to get naggy -class Mergeapp(npyscreen.NPSAppManaged): # type: ignore - def __init__(self, installer: ModelInstallServiceBase): - """Initialize the npyscreen application.""" - super().__init__() - self.installer = installer - - def onStart(self) -> None: - npyscreen.setTheme(npyscreen.Themes.ElegantTheme) - self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") - - -def run_gui(args: Namespace) -> None: - installer = initialize_installer(config) - mergeapp = Mergeapp(installer) - mergeapp.run() - merge_args = mergeapp.merge_arguments - merger = ModelMerger(installer) - merger.merge_diffusion_models_and_save(**merge_args) - logger.info(f'Models merged into new model: "{merge_args.merged_model_name}".') - - -def run_cli(args: Namespace) -> None: - 
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1" - assert ( - args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3 - ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage." - - if not args.merged_model_name: - args.merged_model_name = "+".join(args.model_names) - logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') - - installer = initialize_installer(config) - store = installer.record_store - assert ( - len(store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber - ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' - - merger = ModelMerger(installer) - model_keys = [] - for name in args.model_names: - if len(name) == 32 and re.match(r"^[0-9a-f]$", name): - model_keys.append(name) - else: - models = store.search_by_attr( - model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model) - ) - assert len(models) > 0, f"{name}: Unknown model" - assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead." - model_keys.append(models[0].key) - - merger.merge_diffusion_models_and_save( - alpha=args.alpha, - model_keys=model_keys, - merged_model_name=args.merged_model_name, - interp=args.interp, - force=args.force, - ) - logger.info(f'Models merged into new model: "{args.merged_model_name}".') - - -def main() -> None: - args = _parse_args() - if args.root_dir: - config.parse_args(["--root", str(args.root_dir)]) - else: - config.parse_args([]) - - try: - if args.front_end: - run_gui(args) - else: - run_cli(args) - except widget.NotEnoughSpaceForWidget as e: - if str(e).startswith("Height of 1 allocated"): - logger.error("You need to have at least two diffusers models defined in models.yaml in order to merge") - else: - logger.error("Not enough room for the user interface. 
Try making this window larger.") - sys.exit(-1) - except Exception as e: - logger.error(str(e)) - sys.exit(-1) - except KeyboardInterrupt: - sys.exit(-1) - - -if __name__ == "__main__": - main() diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 5694432ebdc..55f7e865410 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -20,7 +20,7 @@ ) from invokeai.app.services.model_records import UnknownModelException from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 OS = platform.uname().system diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 852e1da979c..57515ac81b1 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -26,7 +26,7 @@ ) from invokeai.backend.model_manager.metadata import BaseMetadata from invokeai.backend.util.logging import InvokeAILogger -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.fixtures.sqlite_database import create_mock_sqlite_database diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py index 6a3ec510a2c..9ed3c9bc507 100644 --- a/tests/backend/ip_adapter/test_ip_adapter.py +++ b/tests/backend/ip_adapter/test_ip_adapter.py @@ -2,7 +2,7 @@ import torch from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType from invokeai.backend.util.test_utils import install_and_load_model diff --git a/tests/backend/model_manager_2/data/invokeai_root/README b/tests/backend/model_manager/data/invokeai_root/README similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/README rename to tests/backend/model_manager/data/invokeai_root/README diff --git a/tests/backend/model_manager_2/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml b/tests/backend/model_manager/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml rename to tests/backend/model_manager/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml diff --git a/tests/backend/model_manager_2/data/invokeai_root/databases/README b/tests/backend/model_manager/data/invokeai_root/databases/README similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/databases/README rename to tests/backend/model_manager/data/invokeai_root/databases/README diff --git a/tests/backend/model_manager_2/data/invokeai_root/models/README b/tests/backend/model_manager/data/invokeai_root/models/README similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/models/README rename to tests/backend/model_manager/data/invokeai_root/models/README diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/model_index.json 
b/tests/backend/model_manager/data/test_files/test-diffusers-main/model_index.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/model_index.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/model_index.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/scheduler/scheduler_config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/scheduler/scheduler_config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/scheduler/scheduler_config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/scheduler/scheduler_config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/merges.txt 
b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/merges.txt similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/merges.txt rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/merges.txt diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/vocab.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/vocab.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/vocab.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/vocab.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/merges.txt b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/merges.txt similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/merges.txt rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/merges.txt diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/vocab.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/vocab.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/vocab.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/vocab.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/unet/config.json similarity index 100% 
rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/unet/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/vae/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/vae/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test_embedding.safetensors b/tests/backend/model_manager/data/test_files/test_embedding.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test_embedding.safetensors rename to tests/backend/model_manager/data/test_files/test_embedding.safetensors diff --git a/tests/backend/model_manager_2/model_loading/test_model_load.py b/tests/backend/model_manager/model_loading/test_model_load.py similarity index 61% rename from tests/backend/model_manager_2/model_loading/test_model_load.py rename to tests/backend/model_manager/model_loading/test_model_load.py index a7a64e91ac0..38d9b8afb8c 100644 --- a/tests/backend/model_manager_2/model_loading/test_model_load.py +++ b/tests/backend/model_manager/model_loading/test_model_load.py @@ -5,17 +5,16 @@ from pathlib import Path from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.backend.embeddings.textual_inversion import 
TextualInversionModelRaw -from invokeai.backend.model_manager.load import AnyModelLoader -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from invokeai.app.services.model_load import ModelLoadServiceBase +from invokeai.backend.textual_inversion import TextualInversionModelRaw +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 - -def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path): +def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path): store = mm2_installer.record_store matches = store.search_by_attr(model_name="test_embedding") assert len(matches) == 0 key = mm2_installer.register_path(embedding_file) - loaded_model = mm2_loader.load_model(store.get_model(key)) + loaded_model = mm2_loader.load_model_by_config(store.get_model(key)) assert loaded_model is not None assert loaded_model.config.key == key with loaded_model as model: diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py similarity index 80% rename from tests/backend/model_manager_2/model_manager_2_fixtures.py rename to tests/backend/model_manager/model_manager_fixtures.py index ebdc9cb5cd6..5f7f44c0188 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -6,24 +6,27 @@ from typing import Any, Dict, List import pytest +from pytest import FixtureRequest from pydantic import BaseModel from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadQueueService +from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService +from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL -from invokeai.app.services.model_records import ModelRecordServiceSQL +from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( BaseModelType, ModelFormat, ModelType, ) -from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.util.logging import InvokeAILogger -from tests.backend.model_manager_2.model_metadata.metadata_examples import ( +from tests.backend.model_manager.model_metadata.metadata_examples import ( RepoCivitaiModelMetadata1, RepoCivitaiVersionMetadata1, RepoHFMetadata1, @@ -86,22 +89,71 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: app_config = InvokeAIAppConfig( root=mm2_root_dir, models_dir=mm2_root_dir / "models", + log_level="info", ) return app_config @pytest.fixture -def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader: - logger = InvokeAILogger.get_logger(config=mm2_app_config) +def mm2_download_queue(mm2_session: Session, + request: FixtureRequest + ) -> 
DownloadQueueServiceBase: + download_queue = DownloadQueueService(requests_session=mm2_session) + download_queue.start() + + def stop_queue() -> None: + download_queue.stop() + + request.addfinalizer(stop_queue) + return download_queue + +@pytest.fixture +def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: + return mm2_record_store.metadata_store + +@pytest.fixture +def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: ram_cache = ModelCache( - logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size + logger=InvokeAILogger.get_logger(), + max_cache_size=mm2_app_config.ram_cache_size, + max_vram_cache_size=mm2_app_config.vram_cache_size ) convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) - return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache) + return ModelLoadService(app_config=mm2_app_config, + record_store=mm2_record_store, + ram_cache=ram_cache, + convert_cache=convert_cache, + ) + +@pytest.fixture +def mm2_installer(mm2_app_config: InvokeAIAppConfig, + mm2_download_queue: DownloadQueueServiceBase, + mm2_session: Session, + request: FixtureRequest, + ) -> ModelInstallServiceBase: + logger = InvokeAILogger.get_logger() + db = create_mock_sqlite_database(mm2_app_config, logger) + events = DummyEventService() + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + + installer = ModelInstallService( + app_config=mm2_app_config, + record_store=store, + download_queue=mm2_download_queue, + event_bus=events, + session=mm2_session, + ) + installer.start() + + def stop_installer() -> None: + installer.stop() + + request.addfinalizer(stop_installer) + return installer @pytest.fixture -def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: +def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: logger = InvokeAILogger.get_logger(config=mm2_app_config) db = create_mock_sqlite_database(mm2_app_config, logger) store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) @@ -161,11 +213,15 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL store.add_model("test_config_5", raw5) return store - @pytest.fixture -def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: - return mm2_record_store.metadata_store - +def mm2_model_manager(mm2_record_store: ModelRecordServiceBase, + mm2_installer: ModelInstallServiceBase, + mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase: + return ModelManagerService( + store=mm2_record_store, + install=mm2_installer, + load=mm2_loader + ) @pytest.fixture def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: @@ -252,22 +308,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: return sess -@pytest.fixture -def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase: - logger = InvokeAILogger.get_logger() - db = create_mock_sqlite_database(mm2_app_config, logger) - events = DummyEventService() - store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) - - download_queue = DownloadQueueService(requests_session=mm2_session) - download_queue.start() - - installer = ModelInstallService( - app_config=mm2_app_config, - record_store=store, - download_queue=download_queue, - event_bus=events, - session=mm2_session, - ) - 
installer.start() - return installer diff --git a/tests/backend/model_manager_2/model_metadata/metadata_examples.py b/tests/backend/model_manager/model_metadata/metadata_examples.py similarity index 100% rename from tests/backend/model_manager_2/model_metadata/metadata_examples.py rename to tests/backend/model_manager/model_metadata/metadata_examples.py diff --git a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py b/tests/backend/model_manager/model_metadata/test_model_metadata.py similarity index 99% rename from tests/backend/model_manager_2/model_metadata/test_model_metadata.py rename to tests/backend/model_manager/model_metadata/test_model_metadata.py index f61eab1b5d0..09b18916d38 100644 --- a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py +++ b/tests/backend/model_manager/model_metadata/test_model_metadata.py @@ -19,7 +19,7 @@ UnknownMetadataException, ) from invokeai.backend.model_manager.util import select_hf_files -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None: diff --git a/tests/backend/model_management/test_libc_util.py b/tests/backend/model_manager/test_libc_util.py similarity index 88% rename from tests/backend/model_management/test_libc_util.py rename to tests/backend/model_manager/test_libc_util.py index e13a2fd3a2f..4309dc7c34c 100644 --- a/tests/backend/model_management/test_libc_util.py +++ b/tests/backend/model_manager/test_libc_util.py @@ -1,6 +1,6 @@ import pytest -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 +from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2 def test_libc_util_mallinfo2(): diff --git a/tests/backend/model_management/test_lora.py b/tests/backend/model_manager/test_lora.py similarity index 96% rename from tests/backend/model_management/test_lora.py rename to tests/backend/model_manager/test_lora.py index 14bcc87c892..e124bb68efc 100644 --- a/tests/backend/model_management/test_lora.py +++ b/tests/backend/model_manager/test_lora.py @@ -5,8 +5,8 @@ import pytest import torch -from invokeai.backend.model_management.lora import ModelPatcher -from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw +from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.lora import LoRALayer, LoRAModelRaw @pytest.mark.parametrize( diff --git a/tests/backend/model_management/test_memory_snapshot.py b/tests/backend/model_manager/test_memory_snapshot.py similarity index 87% rename from tests/backend/model_management/test_memory_snapshot.py rename to tests/backend/model_manager/test_memory_snapshot.py index 216cd62171d..87ec8c34ee0 100644 --- a/tests/backend/model_management/test_memory_snapshot.py +++ b/tests/backend/model_manager/test_memory_snapshot.py @@ -1,8 +1,7 @@ import pytest -from invokeai.backend.model_management.libc_util import Struct_mallinfo2 -from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff - +from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2 +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff def test_memory_snapshot_capture(): """Smoke test of MemorySnapshot.capture().""" @@ -26,6 +25,7 @@ def test_memory_snapshot_capture(): def test_get_pretty_snapshot_diff(snapshot_1, 
snapshot_2): """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields.""" msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2) + print(msg) expected_lines = 0 if snapshot_1 is not None and snapshot_2 is not None: diff --git a/tests/backend/model_management/test_model_load_optimization.py b/tests/backend/model_manager/test_model_load_optimization.py similarity index 96% rename from tests/backend/model_management/test_model_load_optimization.py rename to tests/backend/model_manager/test_model_load_optimization.py index a4fe1dd5974..f627f3a2982 100644 --- a/tests/backend/model_management/test_model_load_optimization.py +++ b/tests/backend/model_manager/test_model_load_optimization.py @@ -1,7 +1,7 @@ import pytest import torch -from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init +from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init @pytest.mark.parametrize( diff --git a/tests/backend/model_manager_2/util/test_hf_model_select.py b/tests/backend/model_manager/util/test_hf_model_select.py similarity index 100% rename from tests/backend/model_manager_2/util/test_hf_model_select.py rename to tests/backend/model_manager/util/test_hf_model_select.py diff --git a/tests/conftest.py b/tests/conftest.py index 6e7d559be44..1c816002296 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,2 @@ # conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory # without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html) - - -# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not -# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. 
-from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401 From 98ff731c3de37ebc005faa7a401e41001bd9801e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 17:27:42 +1100 Subject: [PATCH 3/3] final tidying before marking PR as ready for review - Replace AnyModelLoader with ModelLoaderRegistry - Fix type check errors in multiple files - Remove apparently unneeded `get_model_config_enum()` method from model manager - Remove last vestiges of old model manager - Updated tests and documentation resolve conflict with seamless.py --- docs/contributing/MODEL_MANAGER.md | 129 +- .../{model_manager_v2.py => model_manager.py} | 38 +- invokeai/app/api/routers/models.py | 426 ---- invokeai/app/api_app.py | 36 +- invokeai/app/invocations/compel.py | 4 +- invokeai/app/services/config/config_base.py | 2 +- .../invocation_stats/invocation_stats_base.py | 11 +- .../model_install/model_install_base.py | 2 +- .../services/model_load/model_load_base.py | 48 +- .../services/model_load/model_load_default.py | 100 +- .../app/services/model_manager/__init__.py | 2 +- .../model_manager/model_manager_base.py | 33 + .../model_manager/model_manager_default.py | 60 +- .../app/services/shared/invocation_context.py | 4 +- invokeai/backend/install/migrate_to_3.py | 591 ------ .../backend/install/model_install_backend.py | 637 ------ invokeai/backend/ip_adapter/ip_adapter.py | 2 +- invokeai/backend/lora.py | 2 + .../backend/model_management_OLD/README.md | 27 - .../backend/model_management_OLD/__init__.py | 20 - .../convert_ckpt_to_diffusers.py | 1739 ----------------- .../detect_baked_in_vae.py | 31 - invokeai/backend/model_management_OLD/lora.py | 582 ------ .../model_management_OLD/memory_snapshot.py | 99 - .../model_management_OLD/model_cache.py | 553 ------ .../model_load_optimizations.py | 30 - .../model_management_OLD/model_manager.py | 1121 ----------- .../model_management_OLD/model_merge.py | 140 -- .../model_management_OLD/model_probe.py | 664 ------- .../model_management_OLD/model_search.py | 112 -- .../model_management_OLD/models/__init__.py | 167 -- .../model_management_OLD/models/base.py | 681 ------- .../models/clip_vision.py | 82 - .../model_management_OLD/models/controlnet.py | 162 -- .../model_management_OLD/models/ip_adapter.py | 98 - .../model_management_OLD/models/lora.py | 696 ------- .../model_management_OLD/models/sdxl.py | 148 -- .../models/stable_diffusion.py | 337 ---- .../models/stable_diffusion_onnx.py | 150 -- .../models/t2i_adapter.py | 102 - .../models/textual_inversion.py | 87 - .../model_management_OLD/models/vae.py | 179 -- .../backend/model_management_OLD/seamless.py | 84 - invokeai/backend/model_management_OLD/util.py | 79 - invokeai/backend/model_manager/__init__.py | 40 +- .../backend/model_manager/load/__init__.py | 14 +- .../backend/model_manager/load/load_base.py | 130 +- .../model_manager/load/load_default.py | 54 +- .../model_manager/load/memory_snapshot.py | 2 +- .../load/model_loader_registry.py | 122 ++ .../load/model_loaders/controlnet.py | 6 +- .../load/model_loaders/generic_diffusers.py | 66 +- .../load/model_loaders/ip_adapter.py | 5 +- .../model_manager/load/model_loaders/lora.py | 8 +- .../model_manager/load/model_loaders/onnx.py | 13 +- .../load/model_loaders/stable_diffusion.py | 13 +- .../load/model_loaders/textual_inversion.py | 12 +- .../model_manager/load/model_loaders/vae.py | 8 +- .../model_manager/load/optimizations.py | 13 +- 
invokeai/backend/model_manager/merge.py | 4 +- .../model_manager/metadata/metadata_base.py | 9 +- invokeai/backend/model_manager/probe.py | 3 +- invokeai/backend/model_manager/search.py | 6 +- .../backend/model_manager/util/libc_util.py | 7 +- .../backend/model_manager/util/model_util.py | 20 +- invokeai/backend/onnx/onnx_runtime.py | 3 +- invokeai/backend/raw_model.py | 1 + invokeai/backend/stable_diffusion/seamless.py | 94 +- invokeai/backend/textual_inversion.py | 2 + invokeai/backend/util/test_utils.py | 4 +- .../model_loading/test_model_load.py | 21 +- .../model_manager/model_manager_fixtures.py | 54 +- tests/backend/model_manager/test_lora.py | 2 +- .../model_manager/test_memory_snapshot.py | 3 +- 74 files changed, 673 insertions(+), 10363 deletions(-) rename invokeai/app/api/routers/{model_manager_v2.py => model_manager.py} (97%) delete mode 100644 invokeai/app/api/routers/models.py delete mode 100644 invokeai/backend/install/migrate_to_3.py delete mode 100644 invokeai/backend/install/model_install_backend.py delete mode 100644 invokeai/backend/model_management_OLD/README.md delete mode 100644 invokeai/backend/model_management_OLD/__init__.py delete mode 100644 invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py delete mode 100644 invokeai/backend/model_management_OLD/detect_baked_in_vae.py delete mode 100644 invokeai/backend/model_management_OLD/lora.py delete mode 100644 invokeai/backend/model_management_OLD/memory_snapshot.py delete mode 100644 invokeai/backend/model_management_OLD/model_cache.py delete mode 100644 invokeai/backend/model_management_OLD/model_load_optimizations.py delete mode 100644 invokeai/backend/model_management_OLD/model_manager.py delete mode 100644 invokeai/backend/model_management_OLD/model_merge.py delete mode 100644 invokeai/backend/model_management_OLD/model_probe.py delete mode 100644 invokeai/backend/model_management_OLD/model_search.py delete mode 100644 invokeai/backend/model_management_OLD/models/__init__.py delete mode 100644 invokeai/backend/model_management_OLD/models/base.py delete mode 100644 invokeai/backend/model_management_OLD/models/clip_vision.py delete mode 100644 invokeai/backend/model_management_OLD/models/controlnet.py delete mode 100644 invokeai/backend/model_management_OLD/models/ip_adapter.py delete mode 100644 invokeai/backend/model_management_OLD/models/lora.py delete mode 100644 invokeai/backend/model_management_OLD/models/sdxl.py delete mode 100644 invokeai/backend/model_management_OLD/models/stable_diffusion.py delete mode 100644 invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py delete mode 100644 invokeai/backend/model_management_OLD/models/t2i_adapter.py delete mode 100644 invokeai/backend/model_management_OLD/models/textual_inversion.py delete mode 100644 invokeai/backend/model_management_OLD/models/vae.py delete mode 100644 invokeai/backend/model_management_OLD/seamless.py delete mode 100644 invokeai/backend/model_management_OLD/util.py create mode 100644 invokeai/backend/model_manager/load/model_loader_registry.py diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index b19699de73d..8351904b619 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1531,23 +1531,29 @@ Here is a typical initialization pattern: ``` from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.app.services.model_load import ModelLoadService +from 
invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry config = InvokeAIAppConfig.get_config() -store = ModelRecordServiceBase.open(config) -loader = ModelLoadService(config, store) +ram_cache = ModelCache( + max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger +) +convert_cache = ModelConvertCache( + cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size +) +loader = ModelLoadService( + app_config=config, + ram_cache=ram_cache, + convert_cache=convert_cache, + registry=ModelLoaderRegistry +) ``` -Note that we are relying on the contents of the application -configuration to choose the implementation of -`ModelRecordServiceBase`. +### load_model(model_config, [submodel_type], [context]) -> LoadedModel -### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel - -The `load_model_by_key()` method receives the unique key that -identifies the model. It loads the model into memory, gets the model -ready for use, and returns a `LoadedModel` object. +The `load_model()` method takes an `AnyModelConfig` returned by +`ModelRecordService.get_model()` and returns the corresponding loaded +model. It loads the model into memory, gets the model ready for use, +and returns a `LoadedModel` object. The optional second argument, `subtype` is a `SubModelType` string enum, such as "vae". It is mandatory when used with a main model, and @@ -1593,25 +1599,6 @@ with model_info as vae: - `ModelNotFoundException` -- key in database but model not found at path - `NotImplementedException` -- the loader doesn't know how to load this type of model -### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel - -This is similar to `load_model_by_key`, but instead it accepts the -combination of the model's name, type and base, which it passes to the -model record config store for retrieval. If successful, this method -returns a `LoadedModel`. It can raise the following exceptions: - -``` -UnknownModelException -- model with these attributes not known -NotImplementedException -- the loader doesn't know how to load this type of model -ValueError -- more than one model matches this combination of base/type/name -``` - -### load_model_by_config(config, [submodel], [context]) -> LoadedModel - -This method takes an `AnyModelConfig` returned by -ModelRecordService.get_model() and returns the corresponding loaded -model. It may raise a `NotImplementedException`. - ### Emitting model loading events When the `context` argument is passed to `load_model_*()`, it will @@ -1656,7 +1643,7 @@ onnx models. To install a new loader, place it in `invokeai/backend/model_manager/load/model_loaders`. Inherit from -`ModelLoader` and use the `@AnyModelLoader.register()` decorator to +`ModelLoader` and use the `@ModelLoaderRegistry.register()` decorator to indicate what type of models the loader can handle. Here is a complete example from `generic_diffusers.py`, which is able @@ -1674,12 +1661,11 @@ from invokeai.backend.model_manager import ( ModelType, SubModelType, ) -from ..load_base import AnyModelLoader -from ..load_default import ModelLoader +from .. 
import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) class GenericDiffusersLoader(ModelLoader): """Class to load simple diffusers models.""" @@ -1728,3 +1714,74 @@ model. It does whatever it needs to do to get the model into diffusers format, and returns the Path of the resulting model. (The path should ordinarily be the same as `output_path`.) +## The ModelManagerService object + +For convenience, the API provides a `ModelManagerService` object which +gives a single point of access to the major model manager +services. This object is created at initialization time and can be +found in the global `ApiDependencies.invoker.services.model_manager` +object, or in `context.services.model_manager` from within an +invocation. + +In the examples below, we have retrieved the manager using: +``` +mm = ApiDependencies.invoker.services.model_manager +``` + +The following properties and methods will be available: + +### mm.store + +This retrieves the `ModelRecordService` associated with the +manager. Example: + +``` +configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5') +``` + +### mm.install + +This retrieves the `ModelInstallService` associated with the manager. +Example: + +``` +job = mm.install.heuristic_import(`https://civitai.com/models/58390/detail-tweaker-lora-lora`) +``` + +### mm.load + +This retrieves the `ModelLoaderService` associated with the manager. Example: + +``` +configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5') +assert len(configs) > 0 + +loaded_model = mm.load.load_model(configs[0]) +``` + +The model manager also offers a few convenience shortcuts for loading +models: + +### mm.load_model_by_config(model_config, [submodel], [context]) -> LoadedModel + +Same as `mm.load.load_model()`. + +### mm.load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel + +This accepts the combination of the model's name, type and base, which +it passes to the model record config store for retrieval. If a unique +model config is found, this method returns a `LoadedModel`. It can +raise the following exceptions: + +``` +UnknownModelException -- model with these attributes not known +NotImplementedException -- the loader doesn't know how to load this type of model +ValueError -- more than one model matches this combination of base/type/name +``` + +### mm.load_model_by_key(key, [submodel], [context]) -> LoadedModel + +This method takes a model key, looks it up using the +`ModelRecordServiceBase` object in `mm.store`, and passes the returned +model configuration to `load_model_by_config()`. It may raise a +`NotImplementedException`. 
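
The three convenience shortcuts described at the end of the new MODEL_MANAGER.md section are documented individually but never shown side by side, so here is a minimal sketch (an editorial illustration, not part of the patch) of how they might be called together. It assumes only the signatures given in the doc headings above: `search_by_attr()` is the record-store query used elsewhere in this patch, the keyword names follow the doc headings, and "stable-diffusion-v1-5" is the same illustrative model name the documentation uses, not a model guaranteed to be installed.

```
# Hedged sketch: exercises the ModelManagerService shortcuts documented above.
# Names below come from the doc headings and from imports used in this patch;
# the model name is illustrative and may not exist in a given install.
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType


def load_vae_three_ways(mm: ModelManagerServiceBase) -> None:
    """Load the same VAE submodel via each documented shortcut."""
    # 1. Query the record store ourselves, then hand the config to the loader.
    #    Per the doc heading, the optional second argument is the submodel type.
    configs = mm.store.search_by_attr(
        model_name="stable-diffusion-v1-5",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
    )
    assert len(configs) == 1
    loaded = mm.load.load_model(configs[0], SubModelType("vae"))

    # 2. Resolve name/base/type and load in one call; per the docs this raises
    #    UnknownModelException or ValueError on zero or multiple matches.
    loaded = mm.load_model_by_attr(
        model_name="stable-diffusion-v1-5",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        submodel=SubModelType("vae"),
    )

    # 3. Load directly from the record key stored in the database.
    loaded = mm.load_model_by_key(configs[0].key, submodel=SubModelType("vae"))

    # A LoadedModel is a context manager: the model is locked into VRAM only
    # for the duration of the `with` block.
    with loaded as vae:
        assert vae is not None
```

The split mirrors the documentation's intent: the attribute-based form is convenient but fails loudly when the name/base/type triple is unknown or ambiguous, while the config- and key-based forms are unambiguous because they start from a unique database record.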
diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager.py similarity index 97% rename from invokeai/app/api/routers/model_manager_v2.py rename to invokeai/app/api/routers/model_manager.py index 2471e0d8c9b..6b7111dd2ce 100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager.py @@ -35,7 +35,7 @@ from ..dependencies import ApiDependencies -model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) +model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"]) class ModelsList(BaseModel): @@ -135,7 +135,7 @@ class ModelTagSet(BaseModel): ############################################################################## -@model_manager_v2_router.get( +@model_manager_router.get( "/", operation_id="list_model_records", ) @@ -164,7 +164,7 @@ async def list_model_records( return ModelsList(models=found_models) -@model_manager_v2_router.get( +@model_manager_router.get( "/i/{key}", operation_id="get_model_record", responses={ @@ -188,7 +188,7 @@ async def get_model_record( raise HTTPException(status_code=404, detail=str(e)) -@model_manager_v2_router.get("/summary", operation_id="list_model_summary") +@model_manager_router.get("/summary", operation_id="list_model_summary") async def list_model_summary( page: int = Query(default=0, description="The page to get"), per_page: int = Query(default=10, description="The number of models per page"), @@ -200,7 +200,7 @@ async def list_model_summary( return results -@model_manager_v2_router.get( +@model_manager_router.get( "/meta/i/{key}", operation_id="get_model_metadata", responses={ @@ -223,7 +223,7 @@ async def get_model_metadata( return result -@model_manager_v2_router.get( +@model_manager_router.get( "/tags", operation_id="list_tags", ) @@ -234,7 +234,7 @@ async def list_tags() -> Set[str]: return result -@model_manager_v2_router.get( +@model_manager_router.get( "/tags/search", operation_id="search_by_metadata_tags", ) @@ -247,7 +247,7 @@ async def search_by_metadata_tags( return ModelsList(models=results) -@model_manager_v2_router.patch( +@model_manager_router.patch( "/i/{key}", operation_id="update_model_record", responses={ @@ -281,7 +281,7 @@ async def update_model_record( return model_response -@model_manager_v2_router.delete( +@model_manager_router.delete( "/i/{key}", operation_id="del_model_record", responses={ @@ -311,7 +311,7 @@ async def del_model_record( raise HTTPException(status_code=404, detail=str(e)) -@model_manager_v2_router.post( +@model_manager_router.post( "/i/", operation_id="add_model_record", responses={ @@ -349,7 +349,7 @@ async def add_model_record( return result -@model_manager_v2_router.post( +@model_manager_router.post( "/heuristic_import", operation_id="heuristic_import_model", responses={ @@ -416,7 +416,7 @@ async def heuristic_import( return result -@model_manager_v2_router.post( +@model_manager_router.post( "/install", operation_id="import_model", responses={ @@ -516,7 +516,7 @@ async def import_model( return result -@model_manager_v2_router.get( +@model_manager_router.get( "/import", operation_id="list_model_install_jobs", ) @@ -544,7 +544,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: return jobs -@model_manager_v2_router.get( +@model_manager_router.get( "/import/{id}", operation_id="get_model_install_job", responses={ @@ -564,7 +564,7 @@ async def get_model_install_job(id: int = Path(description="Model install id")) raise HTTPException(status_code=404, detail=str(e)) 
-@model_manager_v2_router.delete( +@model_manager_router.delete( "/import/{id}", operation_id="cancel_model_install_job", responses={ @@ -583,7 +583,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job installer.cancel_job(job) -@model_manager_v2_router.patch( +@model_manager_router.patch( "/import", operation_id="prune_model_install_jobs", responses={ @@ -597,7 +597,7 @@ async def prune_model_install_jobs() -> Response: return Response(status_code=204) -@model_manager_v2_router.patch( +@model_manager_router.patch( "/sync", operation_id="sync_models_to_config", responses={ @@ -616,7 +616,7 @@ async def sync_models_to_config() -> Response: return Response(status_code=204) -@model_manager_v2_router.put( +@model_manager_router.put( "/convert/{key}", operation_id="convert_model", responses={ @@ -694,7 +694,7 @@ async def convert_model( return new_config -@model_manager_v2_router.put( +@model_manager_router.put( "/merge", operation_id="merge", responses={ diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py deleted file mode 100644 index 0aa7aa0ecba..00000000000 --- a/invokeai/app/api/routers/models.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein - -import pathlib -from typing import Annotated, List, Literal, Optional, Union - -from fastapi import Body, Path, Query, Response -from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter -from starlette.exceptions import HTTPException - -from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType -from invokeai.backend.model_management.models import ( - OPENAPI_MODEL_CONFIGS, - InvalidModelException, - ModelNotFoundException, - SchedulerPredictionType, -) - -from ..dependencies import ApiDependencies - -models_router = APIRouter(prefix="/v1/models", tags=["models"]) - -UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse) - -ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelResponseValidator = TypeAdapter(ImportModelResponse) - -ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse) - -MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] - - -class ModelsList(BaseModel): - models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] - - model_config = ConfigDict(use_enum_values=True) - - -ModelsListValidator = TypeAdapter(ModelsList) - - -@models_router.get( - "/", - operation_id="list_models", - responses={200: {"model": ModelsList}}, -) -async def list_models( - base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), - model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), -) -> ModelsList: - """Gets a list of models""" - if base_models and len(base_models) > 0: - models_raw = [] - for base_model in base_models: - models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)) - else: - models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type) - models = ModelsListValidator.validate_python({"models": models_raw}) - return models - - -@models_router.patch( - "/{base_model}/{model_type}/{model_name}", - 
operation_id="update_model", - responses={ - 200: {"description": "The model was updated successfully"}, - 400: {"description": "Bad request"}, - 404: {"description": "The model could not be found"}, - 409: {"description": "There is already a model corresponding to the new name"}, - }, - status_code=200, - response_model=UpdateModelResponse, -) -async def update_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> UpdateModelResponse: - """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" - logger = ApiDependencies.invoker.services.logger - - try: - previous_info = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - - # rename operation requested - if info.model_name != model_name or info.base_model != base_model: - ApiDependencies.invoker.services.model_manager.rename_model( - base_model=base_model, - model_type=model_type, - model_name=model_name, - new_name=info.model_name, - new_base=info.base_model, - ) - logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}") - # update information to support an update of attributes - model_name = info.model_name - base_model = info.base_model - new_info = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - if new_info.get("path") != previous_info.get( - "path" - ): # model manager moved model path during rename - don't overwrite it - info.path = new_info.get("path") - - # replace empty string values with None/null to avoid phenomenon of vae: '' - info_dict = info.model_dump() - info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()} - - ApiDependencies.invoker.services.model_manager.update_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - model_attributes=info_dict, - ) - - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - model_response = UpdateModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException as e: - raise HTTPException(status_code=404, detail=str(e)) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - except Exception as e: - logger.error(str(e)) - raise HTTPException(status_code=400, detail=str(e)) - - return model_response - - -@models_router.post( - "/import", - operation_id="import_model", - responses={ - 201: {"description": "The model imported successfully"}, - 404: {"description": "The model could not be found"}, - 415: {"description": "Unrecognized file/folder format"}, - 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - }, - status_code=201, - response_model=ImportModelResponse, -) -async def import_model( - location: str = Body(description="A model path, repo_id or URL to import"), - prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body( - description="Prediction type for SDv2 checkpoints and rare SDv1 
checkpoints", - default=None, - ), -) -> ImportModelResponse: - """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically""" - - location = location.strip("\"' ") - items_to_import = {location} - prediction_types = {x.value: x for x in SchedulerPredictionType} - logger = ApiDependencies.invoker.services.logger - - try: - installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( - items_to_import=items_to_import, - prediction_type_helper=lambda x: prediction_types.get(prediction_type), - ) - info = installed_models.get(location) - - if not info: - logger.error("Import failed") - raise HTTPException(status_code=415) - - logger.info(f"Successfully imported {location}, got {info}") - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.name, base_model=info.base_model, model_type=info.model_type - ) - return ImportModelResponseValidator.validate_python(model_raw) - - except ModelNotFoundException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - except InvalidModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=415) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - - -@models_router.post( - "/add", - operation_id="add_model", - responses={ - 201: {"description": "The model added successfully"}, - 404: {"description": "The model could not be found"}, - 424: {"description": "The model appeared to add successfully, but could not be found in the model manager"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - }, - status_code=201, - response_model=ImportModelResponse, -) -async def add_model( - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> ImportModelResponse: - """Add a model using the configuration information appropriate for its type. 
Only local models can be added by path""" - - logger = ApiDependencies.invoker.services.logger - - try: - ApiDependencies.invoker.services.model_manager.add_model( - info.model_name, - info.base_model, - info.model_type, - model_attributes=info.model_dump(), - ) - logger.info(f"Successfully added {info.model_name}") - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.model_name, - base_model=info.base_model, - model_type=info.model_type, - ) - return ImportModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - - -@models_router.delete( - "/{base_model}/{model_type}/{model_name}", - operation_id="del_model", - responses={ - 204: {"description": "Model deleted successfully"}, - 404: {"description": "Model not found"}, - }, - status_code=204, - response_model=None, -) -async def delete_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), -) -> Response: - """Delete Model""" - logger = ApiDependencies.invoker.services.logger - - try: - ApiDependencies.invoker.services.model_manager.del_model( - model_name, base_model=base_model, model_type=model_type - ) - logger.info(f"Deleted model: {model_name}") - return Response(status_code=204) - except ModelNotFoundException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - - -@models_router.put( - "/convert/{base_model}/{model_type}/{model_name}", - operation_id="convert_model", - responses={ - 200: {"description": "Model converted successfully"}, - 400: {"description": "Bad request"}, - 404: {"description": "Model not found"}, - }, - status_code=200, - response_model=ConvertModelResponse, -) -async def convert_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), - convert_dest_directory: Optional[str] = Query( - default=None, description="Save the converted model to the designated directory" - ), -) -> ConvertModelResponse: - """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" - logger = ApiDependencies.invoker.services.logger - try: - logger.info(f"Converting model: {model_name}") - dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None - ApiDependencies.invoker.services.model_manager.convert_model( - model_name, - base_model=base_model, - model_type=model_type, - convert_dest_directory=dest, - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name, base_model=base_model, model_type=model_type - ) - response = ConvertModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException as e: - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}") - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return response - - -@models_router.get( - "/search", - operation_id="search_for_models", - responses={ - 200: {"description": "Directory searched successfully"}, - 404: {"description": "Invalid directory path"}, - }, - status_code=200, - response_model=List[pathlib.Path], -) -async def 
search_for_models( - search_path: pathlib.Path = Query(description="Directory path to search for models"), -) -> List[pathlib.Path]: - if not search_path.is_dir(): - raise HTTPException( - status_code=404, - detail=f"The search path '{search_path}' does not exist or is not directory", - ) - return ApiDependencies.invoker.services.model_manager.search_for_models(search_path) - - -@models_router.get( - "/ckpt_confs", - operation_id="list_ckpt_configs", - responses={ - 200: {"description": "paths retrieved successfully"}, - }, - status_code=200, - response_model=List[pathlib.Path], -) -async def list_ckpt_configs() -> List[pathlib.Path]: - """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" - return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() - - -@models_router.post( - "/sync", - operation_id="sync_to_config", - responses={ - 201: {"description": "synchronization successful"}, - }, - status_code=201, - response_model=bool, -) -async def sync_to_config() -> bool: - """Call after making changes to models.yaml, autoimport directories or models directory to synchronize - in-memory data structures with disk data structures.""" - ApiDependencies.invoker.services.model_manager.sync_to_config() - return True - - -# There's some weird pydantic-fastapi behaviour that requires this to be a separate class -# TODO: After a few updates, see if it works inside the route operation handler? -class MergeModelsBody(BaseModel): - model_names: List[str] = Field(description="model name", min_length=2, max_length=3) - merged_model_name: Optional[str] = Field(description="Name of destination model") - alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5) - interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method") - force: Optional[bool] = Field( - description="Force merging of models created with different versions of diffusers", - default=False, - ) - - merge_dest_directory: Optional[str] = Field( - description="Save the merged model to the designated directory (with 'merged_model_name' appended)", - default=None, - ) - - model_config = ConfigDict(protected_namespaces=()) - - -@models_router.put( - "/merge/{base_model}", - operation_id="merge_models", - responses={ - 200: {"description": "Model converted successfully"}, - 400: {"description": "Incompatible models"}, - 404: {"description": "One or more models not found"}, - }, - status_code=200, - response_model=MergeModelResponse, -) -async def merge_models( - body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)], - base_model: BaseModelType = Path(description="Base model"), -) -> MergeModelResponse: - """Convert a checkpoint model into a diffusers model""" - logger = ApiDependencies.invoker.services.logger - try: - logger.info( - f"Merging models: {body.model_names} into {body.merge_dest_directory or ''}/{body.merged_model_name}" - ) - dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None - result = ApiDependencies.invoker.services.model_manager.merge_models( - model_names=body.model_names, - base_model=base_model, - merged_model_name=body.merged_model_name or "+".join(body.model_names), - alpha=body.alpha, - interp=body.interp, - force=body.force, - merge_dest_directory=dest, - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - result.name, - base_model=base_model, - 
model_type=ModelType.Main, - ) - response = ConvertModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException: - raise HTTPException( - status_code=404, - detail=f"One or more of the models '{body.model_names}' not found", - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return response diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 1831b54c13c..149d47fb962 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -48,7 +48,7 @@ boards, download_queue, images, - model_manager_v2, + model_manager, session_queue, sessions, utilities, @@ -113,7 +113,7 @@ async def shutdown_event() -> None: app.include_router(sessions.session_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api") -app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api") +app.include_router(model_manager.model_manager_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") @@ -175,21 +175,23 @@ def custom_openapi() -> dict[str, Any]: invoker_schema["class"] = "invocation" openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output" - from invokeai.backend.model_management.models import get_model_config_enums - - for model_config_format_enum in set(get_model_config_enums()): - name = model_config_format_enum.__qualname__ - - if name in openapi_schema["components"]["schemas"]: - # print(f"Config with name {name} already defined") - continue - - openapi_schema["components"]["schemas"][name] = { - "title": name, - "description": "An enumeration.", - "type": "string", - "enum": [v.value for v in model_config_format_enum], - } + # This code no longer seems to be necessary? 
+ # Leave it here just in case + # + # from invokeai.backend.model_manager import get_model_config_formats + # formats = get_model_config_formats() + # for model_config_name, enum_set in formats.items(): + + # if model_config_name in openapi_schema["components"]["schemas"]: + # # print(f"Config with name {name} already defined") + # continue + + # openapi_schema["components"]["schemas"][model_config_name] = { + # "title": model_config_name, + # "description": "An enumeration.", + # "type": "string", + # "enum": [v.value for v in enum_set], + # } app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 593121ba60b..517da4375e1 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -18,15 +18,15 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt from invokeai.backend.lora import LoRAModelRaw -from invokeai.backend.model_patcher import ModelPatcher -from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ModelType +from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, ExtraConditioningInfo, SDXLConditioningInfo, ) +from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.util.devices import torch_dtype from .baseinvocation import ( diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index 983df6b4684..c73aa438096 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -68,7 +68,7 @@ def to_yaml(self) -> str: return OmegaConf.to_yaml(conf) @classmethod - def add_parser_arguments(cls, parser) -> None: + def add_parser_arguments(cls, parser: ArgumentParser) -> None: """Dynamically create arguments for a settings parser.""" if "type" in get_type_hints(cls): settings_stanza = get_args(get_type_hints(cls)["type"])[0] diff --git a/invokeai/app/services/invocation_stats/invocation_stats_base.py b/invokeai/app/services/invocation_stats/invocation_stats_base.py index 22624a6579a..ec8a453323d 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_base.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_base.py @@ -29,8 +29,8 @@ """ from abc import ABC, abstractmethod -from contextlib import AbstractContextManager from pathlib import Path +from typing import Iterator from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary @@ -40,18 +40,17 @@ class InvocationStatsServiceBase(ABC): "Abstract base class for recording node memory/time performance statistics" @abstractmethod - def __init__(self): + def __init__(self) -> None: """ Initialize the InvocationStatsService and reset counters to zero """ - pass @abstractmethod def collect_stats( self, invocation: BaseInvocation, graph_execution_state_id: str, - ) -> AbstractContextManager: + ) -> Iterator[None]: """ Return a context object that will capture the statistics on the execution of invocaation. Use with: to place around the part of the code that executes the invocation. 
@@ -61,7 +60,7 @@ def collect_stats( pass @abstractmethod - def reset_stats(self, graph_execution_state_id: str): + def reset_stats(self, graph_execution_state_id: str) -> None: """ Reset all statistics for the indicated graph. :param graph_execution_state_id: The id of the session whose stats to reset. @@ -70,7 +69,7 @@ def reset_stats(self, graph_execution_state_id: str): pass @abstractmethod - def log_stats(self, graph_execution_state_id: str): + def log_stats(self, graph_execution_state_id: str) -> None: """ Write out the accumulated statistics to the log or somewhere else. :param graph_execution_state_id: The id of the session whose stats to log. diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 2f03db0af72..080219af75e 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -14,7 +14,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase -from invokeai.app.services.events import EventServiceBase +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index f4dd905135a..cc80333e932 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -5,7 +5,7 @@ from typing import Optional from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase @@ -15,23 +15,7 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model_by_key( - self, - key: str, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - """ - Given a model's key, load it and return the LoadedModel object. - - :param key: Key of model config to be fetched. - :param submodel: For main (pipeline models), the submodel to fetch. - :param context_data: Invocation context data used for event reporting - """ - pass - - @abstractmethod - def load_model_by_config( + def load_model( self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, @@ -44,34 +28,6 @@ def load_model_by_config( :param submodel: For main (pipeline models), the submodel to fetch. :param context_data: Invocation context data used for event reporting """ - pass - - @abstractmethod - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - """ - Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. 
- - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Name of to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context_data: The invocation context data. - - Exceptions: UnknownModelException -- model with these attributes not known - NotImplementedException -- a model loader was not provided at initialization time - ValueError -- more than one model matches this combination - """ @property @abstractmethod diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index fa96a4672d1..15c6283d8af 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -1,15 +1,18 @@ # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team """Implementation of model loader service.""" -from typing import Optional +from typing import Optional, Type from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.invoker import Invoker -from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType -from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType +from invokeai.backend.model_manager.load import ( + LoadedModel, + ModelLoaderRegistry, + ModelLoaderRegistryBase, +) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.util.logging import InvokeAILogger @@ -18,25 +21,23 @@ class ModelLoadService(ModelLoadServiceBase): - """Wrapper around AnyModelLoader.""" + """Wrapper around ModelLoaderRegistry.""" def __init__( - self, - app_config: InvokeAIAppConfig, - record_store: ModelRecordServiceBase, - ram_cache: ModelCacheBase[AnyModel], - convert_cache: ModelConvertCacheBase, + self, + app_config: InvokeAIAppConfig, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry, ): """Initialize the model load service.""" logger = InvokeAILogger.get_logger(self.__class__.__name__) logger.setLevel(app_config.log_level.upper()) - self._store = record_store - self._any_loader = AnyModelLoader( - app_config=app_config, - logger=logger, - ram_cache=ram_cache, - convert_cache=convert_cache, - ) + self._logger = logger + self._app_config = app_config + self._ram_cache = ram_cache + self._convert_cache = convert_cache + self._registry = registry def start(self, invoker: Invoker) -> None: self._invoker = invoker @@ -44,63 +45,14 @@ def start(self, invoker: Invoker) -> None: @property def ram_cache(self) -> ModelCacheBase[AnyModel]: """Return the RAM cache used by this loader.""" - return self._any_loader.ram_cache + return self._ram_cache @property def convert_cache(self) -> ModelConvertCacheBase: 
"""Return the checkpoint convert cache used by this loader.""" - return self._any_loader.convert_cache - - def load_model_by_key( - self, - key: str, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - """ - Given a model's key, load it and return the LoadedModel object. - - :param key: Key of model config to be fetched. - :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting - """ - config = self._store.get_model(key) - return self.load_model_by_config(config, submodel_type, context_data) + return self._convert_cache - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - """ - Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. - - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Name of to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. - - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - ValueError -- more than one model matches this combination - """ - configs = self._store.search_by_attr(model_name, base_model, model_type) - if len(configs) == 0: - raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") - elif len(configs) > 1: - raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") - else: - return self.load_model_by_key(configs[0].key, submodel) - - def load_model_by_config( + def load_model( self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, @@ -118,7 +70,15 @@ def load_model_by_config( context_data=context_data, model_config=model_config, ) - loaded_model = self._any_loader.load_model(model_config, submodel_type) + + implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore + loaded_model: LoadedModel = implementation( + app_config=self._app_config, + logger=self._logger, + ram_cache=self._ram_cache, + convert_cache=self._convert_cache, + ).load_model(model_config, submodel_type) + if context_data: self._emit_load_event( context_data=context_data, diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index 66707493f71..5455577266a 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -3,7 +3,7 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import LoadedModel -from .model_manager_default import ModelManagerServiceBase, ModelManagerService +from .model_manager_default import ModelManagerService, ModelManagerServiceBase __all__ = [ "ModelManagerServiceBase", diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 1116c82ff1f..c25aa6fb47c 
100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,10 +1,14 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team from abc import ABC, abstractmethod +from typing import Optional from typing_extensions import Self from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.invocation_context import InvocationContextData +from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import LoadedModel from ..config import InvokeAIAppConfig from ..download import DownloadQueueServiceBase @@ -65,3 +69,32 @@ def start(self, invoker: Invoker) -> None: @abstractmethod def stop(self, invoker: Invoker) -> None: pass + + @abstractmethod + def load_model_by_config( + self, + model_config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + pass + + @abstractmethod + def load_model_by_key( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + pass + + @abstractmethod + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index b96341be69e..d029f9e0339 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,10 +1,14 @@ # Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" +from typing import Optional + from typing_extensions import Self from invokeai.app.services.invoker import Invoker -from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache +from invokeai.app.services.shared.invocation_context import InvocationContextData +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -12,7 +16,7 @@ from ..events.events_base import EventServiceBase from ..model_install import ModelInstallService, ModelInstallServiceBase from ..model_load import ModelLoadService, ModelLoadServiceBase -from ..model_records import ModelRecordServiceBase +from ..model_records import ModelRecordServiceBase, UnknownModelException from .model_manager_base import ModelManagerServiceBase @@ -58,6 +62,56 @@ def stop(self, invoker: Invoker) -> None: if hasattr(service, "stop"): service.stop(invoker) + def load_model_by_config( + self, + model_config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + return self.load.load_model(model_config, submodel_type, context_data) + + def load_model_by_key( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + config = self.store.get_model(key) + return self.load.load_model(config, submodel_type, context_data) + + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + """ + Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. + + This is provided for API compatability with the get_model() method + in the original model manager. However, note that LoadedModel is + not the same as the original ModelInfo that ws returned. + + :param model_name: Name of to be fetched. + :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + :param context: The invocation context. 
+ + Exceptions: UnknownModelException -- model with this key not known + NotImplementedException -- a model loader was not provided at initialization time + ValueError -- more than one model matches this combination + """ + configs = self.store.search_by_attr(model_name, base_model, model_type) + if len(configs) == 0: + raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") + elif len(configs) > 1: + raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") + else: + return self.load.load_model(configs[0], submodel, context_data) + @classmethod def build_model_manager( cls, @@ -82,9 +136,9 @@ def build_model_manager( ) loader = ModelLoadService( app_config=app_config, - record_store=model_record_service, ram_cache=ram_cache, convert_cache=convert_cache, + registry=ModelLoaderRegistry, ) installer = ModelInstallService( app_config=app_config, diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 089d09f825c..1395427a97e 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -281,7 +281,7 @@ def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> Loaded # The model manager emits events as it loads the model. It needs the context data to build # the event payloads. - return self._services.model_manager.load.load_model_by_key( + return self._services.model_manager.load_model_by_key( key=key, submodel_type=submodel_type, context_data=self._context_data ) @@ -296,7 +296,7 @@ def load_by_attrs( :param model_type: Type of the model :param submodel: For main (pipeline models), the submodel to fetch """ - return self._services.model_manager.load.load_model_by_attr( + return self._services.model_manager.load_model_by_attr( model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py deleted file mode 100644 index e15eb23f5b2..00000000000 --- a/invokeai/backend/install/migrate_to_3.py +++ /dev/null @@ -1,591 +0,0 @@ -""" -Migrate the models directory and models.yaml file from an existing -InvokeAI 2.3 installation to 3.0.0. 
-""" - -import argparse -import os -import shutil -import warnings -from dataclasses import dataclass -from pathlib import Path -from typing import Union - -import diffusers -import transformers -import yaml -from diffusers import AutoencoderKL, StableDiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from omegaconf import DictConfig, OmegaConf -from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import ModelManager -from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType - -warnings.filterwarnings("ignore") -transformers.logging.set_verbosity_error() -diffusers.logging.set_verbosity_error() - - -# holder for paths that we will migrate -@dataclass -class ModelPaths: - models: Path - embeddings: Path - loras: Path - controlnets: Path - - -class MigrateTo3(object): - def __init__( - self, - from_root: Path, - to_models: Path, - model_manager: ModelManager, - src_paths: ModelPaths, - ): - self.root_directory = from_root - self.dest_models = to_models - self.mgr = model_manager - self.src_paths = src_paths - - @classmethod - def initialize_yaml(cls, yaml_file: Path): - with open(yaml_file, "w") as file: - file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - def create_directory_structure(self): - """ - Create the basic directory structure for the models folder. - """ - for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: - for model_type in [ - ModelType.Main, - ModelType.Vae, - ModelType.Lora, - ModelType.ControlNet, - ModelType.TextualInversion, - ]: - path = self.dest_models / model_base.value / model_type.value - path.mkdir(parents=True, exist_ok=True) - path = self.dest_models / "core" - path.mkdir(parents=True, exist_ok=True) - - @staticmethod - def copy_file(src: Path, dest: Path): - """ - copy a single file with logging - """ - if dest.exists(): - logger.info(f"Skipping existing {str(dest)}") - return - logger.info(f"Copying {str(src)} to {str(dest)}") - try: - shutil.copy(src, dest) - except Exception as e: - logger.error(f"COPY FAILED: {str(e)}") - - @staticmethod - def copy_dir(src: Path, dest: Path): - """ - Recursively copy a directory with logging - """ - if dest.exists(): - logger.info(f"Skipping existing {str(dest)}") - return - - logger.info(f"Copying {str(src)} to {str(dest)}") - try: - shutil.copytree(src, dest) - except Exception as e: - logger.error(f"COPY FAILED: {str(e)}") - - def migrate_models(self, src_dir: Path): - """ - Recursively walk through src directory, probe anything - that looks like a model, and copy the model into the - appropriate location within the destination models directory. 
- """ - directories_scanned = set() - for root, dirs, files in os.walk(src_dir, followlinks=True): - for d in dirs: - try: - model = Path(root, d) - info = ModelProbe().heuristic_probe(model) - if not info: - continue - dest = self._model_probe_to_path(info) / model.name - self.copy_dir(model, dest) - directories_scanned.add(model) - except Exception as e: - logger.error(str(e)) - except KeyboardInterrupt: - raise - for f in files: - # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin - # let them be copied as part of a tree copy operation - try: - if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}: - continue - model = Path(root, f) - if model.parent in directories_scanned: - continue - info = ModelProbe().heuristic_probe(model) - if not info: - continue - dest = self._model_probe_to_path(info) / f - self.copy_file(model, dest) - except Exception as e: - logger.error(str(e)) - except KeyboardInterrupt: - raise - - def migrate_support_models(self): - """ - Copy the clipseg, upscaler, and restoration models to their new - locations. - """ - dest_directory = self.dest_models - if (self.root_directory / "models/clipseg").exists(): - self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg") - if (self.root_directory / "models/realesrgan").exists(): - self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan") - for d in ["codeformer", "gfpgan"]: - path = self.root_directory / "models" / d - if path.exists(): - self.copy_dir(path, dest_directory / f"core/face_restoration/{d}") - - def migrate_tuning_models(self): - """ - Migrate the embeddings, loras and controlnets directories to their new homes. - """ - for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]: - if not src: - continue - if src.is_dir(): - logger.info(f"Scanning {src}") - self.migrate_models(src) - else: - logger.info(f"{src} directory not found; skipping") - continue - - def migrate_conversion_models(self): - """ - Migrate all the models that are needed by the ckpt_to_diffusers conversion - script. 
- """ - - dest_directory = self.dest_models - kwargs = { - "cache_dir": self.root_directory / "models/hub", - # local_files_only = True - } - try: - logger.info("Migrating core tokenizers and text encoders") - target_dir = dest_directory / "core" / "convert" - - self._migrate_pretrained( - BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs - ) - - # sd-1 - repo_id = "openai/clip-vit-large-patch14" - self._migrate_pretrained( - CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs - ) - self._migrate_pretrained( - CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs - ) - - # sd-2 - repo_id = "stabilityai/stable-diffusion-2" - self._migrate_pretrained( - CLIPTokenizer, - repo_id=repo_id, - dest=target_dir / "stable-diffusion-2-clip" / "tokenizer", - **{"subfolder": "tokenizer", **kwargs}, - ) - self._migrate_pretrained( - CLIPTextModel, - repo_id=repo_id, - dest=target_dir / "stable-diffusion-2-clip" / "text_encoder", - **{"subfolder": "text_encoder", **kwargs}, - ) - - # VAE - logger.info("Migrating stable diffusion VAE") - self._migrate_pretrained( - AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs - ) - - # safety checking - logger.info("Migrating safety checker") - repo_id = "CompVis/stable-diffusion-safety-checker" - self._migrate_pretrained( - AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs - ) - self._migrate_pretrained( - StableDiffusionSafetyChecker, - repo_id=repo_id, - dest=target_dir / "stable-diffusion-safety-checker", - **kwargs, - ) - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(str(e)) - - def _model_probe_to_path(self, info: ModelProbeInfo) -> Path: - return Path(self.dest_models, info.base_type.value, info.model_type.value) - - def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs): - if dest.exists() and not force: - logger.info(f"Skipping existing {dest}") - return - model = model_class.from_pretrained(repo_id, **kwargs) - self._save_pretrained(model, dest, overwrite=force) - - def _save_pretrained(self, model, dest: Path, overwrite: bool = False): - model_name = dest.name - if overwrite: - model.save_pretrained(dest, safe_serialization=True) - else: - download_path = dest.with_name(f"{model_name}.downloading") - model.save_pretrained(download_path, safe_serialization=True) - download_path.replace(dest) - - def _download_vae(self, repo_id: str, subfolder: str = None) -> Path: - vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder) - info = ModelProbe().heuristic_probe(vae) - _, model_name = repo_id.split("/") - dest = self._model_probe_to_path(info) / self.unique_name(model_name, info) - vae.save_pretrained(dest, safe_serialization=True) - return dest - - def _vae_path(self, vae: Union[str, dict]) -> Path: - """ - Convert 2.3 VAE stanza to a straight path. 
- """ - vae_path = None - - # First get a path - if isinstance(vae, str): - vae_path = vae - - elif isinstance(vae, DictConfig): - if p := vae.get("path"): - vae_path = p - elif repo_id := vae.get("repo_id"): - if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded - vae_path = "models/core/convert/sd-vae-ft-mse" - return vae_path - else: - vae_path = self._download_vae(repo_id, vae.get("subfolder")) - - assert vae_path is not None, "Couldn't find VAE for this model" - - # if the VAE is in the old models directory, then we must move it into the new - # one. VAEs outside of this directory can stay where they are. - vae_path = Path(vae_path) - if vae_path.is_relative_to(self.src_paths.models): - info = ModelProbe().heuristic_probe(vae_path) - dest = self._model_probe_to_path(info) / vae_path.name - if not dest.exists(): - if vae_path.is_dir(): - self.copy_dir(vae_path, dest) - else: - self.copy_file(vae_path, dest) - vae_path = dest - - if vae_path.is_relative_to(self.dest_models): - rel_path = vae_path.relative_to(self.dest_models) - return Path("models", rel_path) - else: - return vae_path - - def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config): - """ - Migrate a locally-cached diffusers pipeline identified with a repo_id - """ - dest_dir = self.dest_models - - cache = self.root_directory / "models/hub" - kwargs = { - "cache_dir": cache, - "safety_checker": None, - # local_files_only = True, - } - - owner, repo_name = repo_id.split("/") - model_name = model_name or repo_name - model = cache / "--".join(["models", owner, repo_name]) - - if len(list(model.glob("snapshots/**/model_index.json"))) == 0: - return - revisions = [x.name for x in model.glob("refs/*")] - - # if an fp16 is available we use that - revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0] - pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs) - - info = ModelProbe().heuristic_probe(pipeline) - if not info: - return - - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.") - return - - dest = self._model_probe_to_path(info) / model_name - self._save_pretrained(pipeline, dest) - - rel_path = Path("models", dest.relative_to(dest_dir)) - self._add_model(model_name, info, rel_path, **extra_config) - - def migrate_path(self, location: Path, model_name: str = None, **extra_config): - """ - Migrate a model referred to using 'weights' or 'path' - """ - - # handle relative paths - dest_dir = self.dest_models - location = self.root_directory / location - model_name = model_name or location.stem - - info = ModelProbe().heuristic_probe(location) - if not info: - return - - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - logger.warning(f"A model named {model_name} already exists at the destination. 
Skipping migration.") - return - - # uh oh, weights is in the old models directory - move it into the new one - if Path(location).is_relative_to(self.src_paths.models): - dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name) - if location.is_dir(): - self.copy_dir(location, dest) - else: - self.copy_file(location, dest) - location = Path("models", info.base_type.value, info.model_type.value, location.name) - - self._add_model(model_name, info, location, **extra_config) - - def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config): - if info.model_type != ModelType.Main: - return - - self.mgr.add_model( - model_name=model_name, - base_model=info.base_type, - model_type=info.model_type, - clobber=True, - model_attributes={ - "path": str(location), - "description": f"A {info.base_type.value} {info.model_type.value} model", - "model_format": info.format, - "variant": info.variant_type.value, - **extra_config, - }, - ) - - def migrate_defined_models(self): - """ - Migrate models defined in models.yaml - """ - # find any models referred to in old models.yaml - conf = OmegaConf.load(self.root_directory / "configs/models.yaml") - - for model_name, stanza in conf.items(): - try: - passthru_args = {} - - if vae := stanza.get("vae"): - try: - passthru_args["vae"] = str(self._vae_path(vae)) - except Exception as e: - logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"') - logger.warning(str(e)) - - if config := stanza.get("config"): - passthru_args["config"] = config - - if description := stanza.get("description"): - passthru_args["description"] = description - - if repo_id := stanza.get("repo_id"): - logger.info(f"Migrating diffusers model {model_name}") - self.migrate_repo_id(repo_id, model_name, **passthru_args) - - elif location := stanza.get("weights"): - logger.info(f"Migrating checkpoint model {model_name}") - self.migrate_path(Path(location), model_name, **passthru_args) - - elif location := stanza.get("path"): - logger.info(f"Migrating diffusers model {model_name}") - self.migrate_path(Path(location), model_name, **passthru_args) - - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(str(e)) - - def migrate(self): - self.create_directory_structure() - # the configure script is doing this - self.migrate_support_models() - self.migrate_conversion_models() - self.migrate_tuning_models() - self.migrate_defined_models() - - -def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths: - """ - Returns tuple of (embedding_path, lora_path, controlnet_path) - """ - parser = argparse.ArgumentParser(fromfile_prefix_chars="@") - parser.add_argument( - "--embedding_directory", - "--embedding_path", - type=Path, - dest="embedding_path", - default=Path("embeddings"), - ) - parser.add_argument( - "--lora_directory", - dest="lora_path", - type=Path, - default=Path("loras"), - ) - opt, _ = parser.parse_known_args([f"@{str(initfile)}"]) - return ModelPaths( - models=root / "models", - embeddings=root / str(opt.embedding_path).strip('"'), - loras=root / str(opt.lora_path).strip('"'), - controlnets=root / "controlnets", - ) - - -def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths: - """ - Returns tuple of (embedding_path, lora_path, controlnet_path) - """ - # Don't use the config object because it is unforgiving of version updates - # Just use omegaconf directly - opt = OmegaConf.load(initfile) - paths = opt.InvokeAI.Paths - models = paths.get("models_dir", "models") - 
embeddings = paths.get("embedding_dir", "embeddings") - loras = paths.get("lora_dir", "loras") - controlnets = paths.get("controlnet_dir", "controlnets") - return ModelPaths( - models=root / models if models else None, - embeddings=root / embeddings if embeddings else None, - loras=root / loras if loras else None, - controlnets=root / controlnets if controlnets else None, - ) - - -def get_legacy_embeddings(root: Path) -> ModelPaths: - path = root / "invokeai.init" - if path.exists(): - return _parse_legacy_initfile(root, path) - path = root / "invokeai.yaml" - if path.exists(): - return _parse_legacy_yamlfile(root, path) - - -def do_migrate(src_directory: Path, dest_directory: Path): - """ - Migrate models from src to dest InvokeAI root directories - """ - config_file = dest_directory / "configs" / "models.yaml.3" - dest_models = dest_directory / "models.3" - - version_3 = (dest_directory / "models" / "core").exists() - - # Here we create the destination models.yaml file. - # If we are writing into a version 3 directory and the - # file already exists, then we write into a copy of it to - # avoid deleting its previous customizations. Otherwise we - # create a new empty one. - if version_3: # write into the dest directory - try: - shutil.copy(dest_directory / "configs" / "models.yaml", config_file) - except Exception: - MigrateTo3.initialize_yaml(config_file) - mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory - (dest_directory / "models").replace(dest_models) - else: - MigrateTo3.initialize_yaml(config_file) - mgr = ModelManager(config_file) - - paths = get_legacy_embeddings(src_directory) - migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths) - migrator.migrate() - print("Migration successful.") - - if not version_3: - (dest_directory / "models").replace(src_directory / "models.orig") - print(f"Original models directory moved to {dest_directory}/models.orig") - - (dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig") - print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig") - - config_file.replace(config_file.with_suffix("")) - dest_models.replace(dest_models.with_suffix("")) - - -def main(): - parser = argparse.ArgumentParser( - prog="invokeai-migrate3", - description=""" -This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format -'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a - -The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively. 
-It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure -script, which will perform a full upgrade in place.""", - ) - parser.add_argument( - "--from-directory", - dest="src_root", - type=Path, - required=True, - help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")', - ) - parser.add_argument( - "--to-directory", - dest="dest_root", - type=Path, - required=True, - help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")', - ) - args = parser.parse_args() - src_root = args.src_root - assert src_root.is_dir(), f"{src_root} is not a valid directory" - assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory" - assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory" - assert (src_root / "invokeai.init").exists() or ( - src_root / "invokeai.yaml" - ).exists(), f"{src_root} does not contain an InvokeAI init file." - - dest_root = args.dest_root - assert dest_root.is_dir(), f"{dest_root} is not a valid directory" - config = InvokeAIAppConfig.get_config() - config.parse_args(["--root", str(dest_root)]) - - # TODO: revisit - don't rely on invokeai.yaml to exist yet! - dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists() - if not dest_is_setup: - from invokeai.backend.install.invokeai_configure import initialize_rootdir - - initialize_rootdir(dest_root, True) - - do_migrate(src_root, dest_root) - - -if __name__ == "__main__": - main() diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py deleted file mode 100644 index fdbe714f62c..00000000000 --- a/invokeai/backend/install/model_install_backend.py +++ /dev/null @@ -1,637 +0,0 @@ -""" -Utility (backend) functions used by model_install.py -""" -import os -import re -import shutil -import warnings -from dataclasses import dataclass, field -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Callable, Dict, List, Optional, Set, Union - -import requests -import torch -from diffusers import DiffusionPipeline -from diffusers import logging as dlogging -from huggingface_hub import HfApi, HfFolder, hf_hub_url -from omegaconf import OmegaConf -from tqdm import tqdm - -import invokeai.configs as configs -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType -from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType -from invokeai.backend.util import download_with_resume -from invokeai.backend.util.devices import choose_torch_device, torch_dtype - -from ..util.logging import InvokeAILogger - -warnings.filterwarnings("ignore") - -# --------------------------globals----------------------- -config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger(name="InvokeAI") - -# the initial "configs" dir is now bundled in the `invokeai.configs` package -Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" - -Config_preamble = """ -# This file describes the alternative machine learning models -# available to InvokeAI script. -# -# To add a new model, follow the examples below. Each -# model requires a model config file, a weights file, -# and the width and height of the images it -# was trained on. 
-""" - -LEGACY_CONFIGS = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v1-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml", - }, - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v2-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", - }, - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - }, -} - - -@dataclass -class InstallSelections: - install_models: List[str] = field(default_factory=list) - remove_models: List[str] = field(default_factory=list) - - -@dataclass -class ModelLoadInfo: - name: str - model_type: ModelType - base_type: BaseModelType - path: Optional[Path] = None - repo_id: Optional[str] = None - subfolder: Optional[str] = None - description: str = "" - installed: bool = False - recommended: bool = False - default: bool = False - requires: Optional[List[str]] = field(default_factory=list) - - -class ModelInstall(object): - def __init__( - self, - config: InvokeAIAppConfig, - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - model_manager: Optional[ModelManager] = None, - access_token: Optional[str] = None, - civitai_api_key: Optional[str] = None, - ): - self.config = config - self.mgr = model_manager or ModelManager(config.model_conf_path) - self.datasets = OmegaConf.load(Dataset_path) - self.prediction_helper = prediction_type_helper - self.access_token = access_token or HfFolder.get_token() - self.civitai_api_key = civitai_api_key or config.civitai_api_key - self.reverse_paths = self._reverse_paths(self.datasets) - - def all_models(self) -> Dict[str, ModelLoadInfo]: - """ - Return dict of model_key=>ModelLoadInfo objects. - This method consolidates and simplifies the entries in both - models.yaml and INITIAL_MODELS.yaml so that they can - be treated uniformly. It also sorts the models alphabetically - by their name, to improve the display somewhat. 
- """ - model_dict = {} - - # first populate with the entries in INITIAL_MODELS.yaml - for key, value in self.datasets.items(): - name, base, model_type = ModelManager.parse_key(key) - value["name"] = name - value["base_type"] = base - value["model_type"] = model_type - model_info = ModelLoadInfo(**value) - if model_info.subfolder and model_info.repo_id: - model_info.repo_id += f":{model_info.subfolder}" - model_dict[key] = model_info - - # supplement with entries in models.yaml - installed_models = list(self.mgr.list_models()) - - for md in installed_models: - base = md["base_model"] - model_type = md["model_type"] - name = md["model_name"] - key = ModelManager.create_key(name, base, model_type) - if key in model_dict: - model_dict[key].installed = True - else: - model_dict[key] = ModelLoadInfo( - name=name, - base_type=base, - model_type=model_type, - path=value.get("path"), - installed=True, - ) - return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())} - - def _is_autoloaded(self, model_info: dict) -> bool: - path = model_info.get("path") - if not path: - return False - for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]: - if autodir_path := getattr(self.config, autodir): - autodir_path = self.config.root_path / autodir_path - if Path(path).is_relative_to(autodir_path): - return True - return False - - def list_models(self, model_type): - installed = self.mgr.list_models(model_type=model_type) - print() - print(f"Installed models of type `{model_type}`:") - print(f"{'Model Key':50} Model Path") - for i in installed: - print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}") - print() - - # logic here a little reversed to maintain backward compatibility - def starter_models(self, all_models: bool = False) -> Set[str]: - models = set() - for key, _value in self.datasets.items(): - name, base, model_type = ModelManager.parse_key(key) - if all_models or model_type in [ModelType.Main, ModelType.Vae]: - models.add(key) - return models - - def recommended_models(self) -> Set[str]: - starters = self.starter_models(all_models=True) - return {x for x in starters if self.datasets[x].get("recommended", False)} - - def default_model(self) -> str: - starters = self.starter_models() - defaults = [x for x in starters if self.datasets[x].get("default", False)] - return defaults[0] - - def install(self, selections: InstallSelections): - verbosity = dlogging.get_verbosity() # quench NSFW nags - dlogging.set_verbosity_error() - - job = 1 - jobs = len(selections.remove_models) + len(selections.install_models) - - # remove requested models - for key in selections.remove_models: - name, base, mtype = self.mgr.parse_key(key) - logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]") - try: - self.mgr.del_model(name, base, mtype) - except FileNotFoundError as e: - logger.warning(e) - job += 1 - - # add requested models - self._remove_installed(selections.install_models) - self._add_required_models(selections.install_models) - for path in selections.install_models: - logger.info(f"Installing {path} [{job}/{jobs}]") - try: - self.heuristic_import(path) - except (ValueError, KeyError) as e: - logger.error(str(e)) - job += 1 - - dlogging.set_verbosity(verbosity) - self.mgr.commit() - - def heuristic_import( - self, - model_path_id_or_url: Union[str, Path], - models_installed: Set[Path] = None, - ) -> Dict[str, AddModelResult]: - """ - :param model_path_id_or_url: A Path to a local model to import, or a 
string representing its repo_id or URL - :param models_installed: Set of installed models, used for recursive invocation - Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. - """ - - if not models_installed: - models_installed = {} - - model_path_id_or_url = str(model_path_id_or_url).strip("\"' ") - - # A little hack to allow nested routines to retrieve info on the requested ID - self.current_id = model_path_id_or_url - path = Path(model_path_id_or_url) - - # fix relative paths - if path.exists() and not path.is_absolute(): - path = path.absolute() # make relative to current WD - - # checkpoint file, or similar - if path.is_file(): - models_installed.update({str(path): self._install_path(path)}) - - # folders style or similar - elif path.is_dir() and any( - (path / x).exists() - for x in { - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "pytorch_lora_weights.safetensors", - } - ): - models_installed.update({str(model_path_id_or_url): self._install_path(path)}) - - # recursive scan - elif path.is_dir(): - for child in path.iterdir(): - self.heuristic_import(child, models_installed=models_installed) - - # huggingface repo - elif len(str(model_path_id_or_url).split("/")) == 2: - models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) - - # a URL - elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")): - models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) - - else: - raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping") - - return models_installed - - def _remove_installed(self, model_list: List[str]): - all_models = self.all_models() - models_to_remove = [] - - for path in model_list: - key = self.reverse_paths.get(path) - if key and all_models[key].installed: - models_to_remove.append(path) - - for path in models_to_remove: - logger.warning(f"{path} already installed. Skipping") - model_list.remove(path) - - def _add_required_models(self, model_list: List[str]): - additional_models = [] - all_models = self.all_models() - for path in model_list: - if not (key := self.reverse_paths.get(path)): - continue - for requirement in all_models[key].requires: - requirement_key = self.reverse_paths.get(requirement) - if not all_models[requirement_key].installed: - additional_models.append(requirement) - model_list.extend(additional_models) - - # install a model from a local path. The optional info parameter is there to prevent - # the model from being probed twice in the event that it has already been probed. 
- def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult: - info = info or ModelProbe().heuristic_probe(path, self.prediction_helper) - if not info: - logger.warning(f"Unable to parse format of {path}") - return None - model_name = path.stem if path.is_file() else path.name - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - raise ValueError(f'A model named "{model_name}" is already installed.') - attributes = self._make_attributes(path, info) - return self.mgr.add_model( - model_name=model_name, - base_model=info.base_type, - model_type=info.model_type, - model_attributes=attributes, - ) - - def _install_url(self, url: str) -> AddModelResult: - with TemporaryDirectory(dir=self.config.models_path) as staging: - CIVITAI_RE = r".*civitai.com.*" - civit_url = re.match(CIVITAI_RE, url, re.IGNORECASE) - location = download_with_resume( - url, Path(staging), access_token=self.civitai_api_key if civit_url else None - ) - if not location: - logger.error(f"Unable to download {url}. Skipping.") - info = ModelProbe().heuristic_probe(location, self.prediction_helper) - dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name - dest.parent.mkdir(parents=True, exist_ok=True) - models_path = shutil.move(location, dest) - - # staged version will be garbage-collected at this time - return self._install_path(Path(models_path), info) - - def _install_repo(self, repo_id: str) -> AddModelResult: - # hack to recover models stored in subfolders -- - # Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster - subfolder = None - if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id): - repo_id = match.group(1) - subfolder = match.group(2) - - hinfo = HfApi().model_info(repo_id) - - # we try to figure out how to download this most economically - # list all the files in the repo - files = [x.rfilename for x in hinfo.siblings] - if subfolder: - files = [x for x in files if x.startswith(f"{subfolder}/")] - prefix = f"{subfolder}/" if subfolder else "" - - location = None - - with TemporaryDirectory(dir=self.config.models_path) as staging: - staging = Path(staging) - if f"{prefix}model_index.json" in files: - location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline - elif f"{prefix}unet/model.onnx" in files: - location = self._download_hf_model(repo_id, files, staging) - else: - for suffix in ["safetensors", "bin"]: - if f"{prefix}pytorch_lora_weights.{suffix}" in files: - location = self._download_hf_model( - repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder - ) # LoRA - break - elif ( - self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files - ): # vae, controlnet or some other standalone - files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}diffusion_pytorch_model.{suffix}" in files: - files = ["config.json", f"diffusion_pytorch_model.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}learned_embeds.{suffix}" in files: - location = self._download_hf_model( - repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder - ) - break - elif ( - f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files - ): # IP-Adapter - files = ["image_encoder.txt", f"ip_adapter.{suffix}"] - location = 
self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files: - # This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted - # by InvokeAI for use with IP-Adapters. - files = ["config.json", f"model.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - if not location: - logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.") - return {} - - info = ModelProbe().heuristic_probe(location, self.prediction_helper) - if not info: - logger.warning(f"Could not probe {location}. Skipping install.") - return {} - dest = ( - self.config.models_path - / info.base_type.value - / info.model_type.value - / self._get_model_name(repo_id, location) - ) - if dest.exists(): - shutil.rmtree(dest) - shutil.copytree(location, dest) - return self._install_path(dest, info) - - def _get_model_name(self, path_name: str, location: Path) -> str: - """ - Calculate a name for the model - primitive implementation. - """ - if key := self.reverse_paths.get(path_name): - (name, base, mtype) = ModelManager.parse_key(key) - return name - elif location.is_dir(): - return location.name - else: - return location.stem - - def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict: - model_name = path.name if path.is_dir() else path.stem - description = f"{info.base_type.value} {info.model_type.value} model {model_name}" - if key := self.reverse_paths.get(self.current_id): - if key in self.datasets: - description = self.datasets[key].get("description") or description - - rel_path = self.relative_to_root(path, self.config.models_path) - - attributes = { - "path": str(rel_path), - "description": str(description), - "model_format": info.format, - } - legacy_conf = None - if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX: - attributes.update( - { - "variant": info.variant_type, - } - ) - if info.format == "checkpoint": - try: - possible_conf = path.with_suffix(".yaml") - if possible_conf.exists(): - legacy_conf = str(self.relative_to_root(possible_conf)) - elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: - legacy_conf = Path( - self.config.legacy_conf_dir, - LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type], - ) - else: - legacy_conf = Path( - self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type] - ) - except KeyError: - legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess - - if info.model_type == ModelType.ControlNet and info.format == "checkpoint": - possible_conf = path.with_suffix(".yaml") - if possible_conf.exists(): - legacy_conf = str(self.relative_to_root(possible_conf)) - else: - legacy_conf = Path( - self.config.root_path, - "configs/controlnet", - ("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"), - ) - - if legacy_conf: - attributes.update({"config": str(legacy_conf)}) - return attributes - - def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path: - root = root or self.config.root_path - if path.is_relative_to(root): - return path.relative_to(root) - else: - return path - - def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path: - """ - Retrieve a StableDiffusion model from cache or remote and then - does a save_pretrained() to the indicated staging area. 
- """ - _, name = repo_id.split("/") - precision = torch_dtype(choose_torch_device()) - variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"] - - model = None - for variant in variants: - try: - model = DiffusionPipeline.from_pretrained( - repo_id, - variant=variant, - torch_dtype=precision, - safety_checker=None, - subfolder=subfolder, - ) - except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors - if "fp16" not in str(e): - print(e) - - if model: - break - - if not model: - logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.") - return None - model.save_pretrained(staging / name, safe_serialization=True) - return staging / name - - def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path: - _, name = repo_id.split("/") - location = staging / name - paths = [] - for filename in files: - filePath = Path(filename) - p = hf_download_with_resume( - repo_id, - model_dir=location / filePath.parent, - model_name=filePath.name, - access_token=self.access_token, - subfolder=filePath.parent / subfolder if subfolder else filePath.parent, - ) - if p: - paths.append(p) - else: - logger.warning(f"Could not download {filename} from {repo_id}.") - - return location if len(paths) > 0 else None - - @classmethod - def _reverse_paths(cls, datasets) -> dict: - """ - Reverse mapping from repo_id/path to destination name. - """ - return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()} - - -# ------------------------------------- -def yes_or_no(prompt: str, default_yes=True): - default = "y" if default_yes else "n" - response = input(f"{prompt} [{default}] ") or default - if default_yes: - return response[0] not in ("n", "N") - else: - return response[0] in ("y", "Y") - - -# --------------------------------------------- -def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs): - logger = InvokeAILogger.get_logger("InvokeAI") - logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage()) - - model = model_class.from_pretrained( - model_name, - resume_download=True, - **kwargs, - ) - model.save_pretrained(destination, safe_serialization=True) - return destination - - -# --------------------------------------------- -def hf_download_with_resume( - repo_id: str, - model_dir: str, - model_name: str, - model_dest: Path = None, - access_token: str = None, - subfolder: str = None, -) -> Path: - model_dest = model_dest or Path(os.path.join(model_dir, model_name)) - os.makedirs(model_dir, exist_ok=True) - - url = hf_hub_url(repo_id, model_name, subfolder=subfolder) - - header = {"Authorization": f"Bearer {access_token}"} if access_token else {} - open_mode = "wb" - exist_size = 0 - - if os.path.exists(model_dest): - exist_size = os.path.getsize(model_dest) - header["Range"] = f"bytes={exist_size}-" - open_mode = "ab" - - resp = requests.get(url, headers=header, stream=True) - total = int(resp.headers.get("content-length", 0)) - - if resp.status_code == 416: # "range not satisfiable", which means nothing to return - logger.info(f"{model_name}: complete file found. Skipping.") - return model_dest - elif resp.status_code == 404: - logger.warning("File not found") - return None - elif resp.status_code != 200: - logger.warning(f"{model_name}: {resp.reason}") - elif exist_size > 0: - logger.info(f"{model_name}: partial file found. 
Resuming...") - else: - logger.info(f"{model_name}: Downloading...") - - try: - with ( - open(model_dest, open_mode) as file, - tqdm( - desc=model_name, - initial=exist_size, - total=total + exist_size, - unit="iB", - unit_scale=True, - unit_divisor=1000, - ) as bar, - ): - for data in resp.iter_content(chunk_size=1024): - size = file.write(data) - bar.update(size) - except Exception as e: - logger.error(f"An error occurred while downloading {model_name}: {str(e)}") - return None - return model_dest diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 3ba6fc5a23c..e51966c779c 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -9,8 +9,8 @@ from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights -from .resampler import Resampler from ..raw_model import RawModel +from .resampler import Resampler class ImageProjModel(torch.nn.Module): diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index fb0c23067fb..0b7128034a2 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -10,6 +10,7 @@ from typing_extensions import Self from invokeai.backend.model_manager import BaseModelType + from .raw_model import RawModel @@ -366,6 +367,7 @@ def to( AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] + class LoRAModelRaw(RawModel): # (torch.nn.Module): _name: str layers: Dict[str, AnyLoRALayer] diff --git a/invokeai/backend/model_management_OLD/README.md b/invokeai/backend/model_management_OLD/README.md deleted file mode 100644 index 0d94f39642e..00000000000 --- a/invokeai/backend/model_management_OLD/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Model Cache - -## `glibc` Memory Allocator Fragmentation - -Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. a OOM crash with `max_cache_size=16` on a system with 32GB of RAM). - -This problem may also exist on other OSes, and other `libc` implementations. But, at the time of writing, it has only been investigated on linux with `glibc`. - -To better understand how the `glibc` memory allocator works, see these references: -- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html -- Details: https://sourceware.org/glibc/wiki/MallocInternals - -Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena making them vulnerable to the problem of fragmentation. - -We can work around this memory fragmentation issue by setting the following env var: - -```bash -# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed. 
-MALLOC_MMAP_THRESHOLD_=1048576 -``` - -See the following references for more information about the `malloc` tunable parameters: -- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html -- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html -- https://man7.org/linux/man-pages/man3/mallopt.3.html - -The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected. diff --git a/invokeai/backend/model_management_OLD/__init__.py b/invokeai/backend/model_management_OLD/__init__.py deleted file mode 100644 index d523a7a0c8d..00000000000 --- a/invokeai/backend/model_management_OLD/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# ruff: noqa: I001, F401 -""" -Initialization file for invokeai.backend.model_management -""" -# This import must be first -from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType -from .lora import ModelPatcher, ONNXModelPatcher -from .model_cache import ModelCache - -from .models import ( - BaseModelType, - DuplicateModelException, - ModelNotFoundException, - ModelType, - ModelVariantType, - SubModelType, -) - -# This import must be last -from .model_merge import MergeInterpolationMethod, ModelMerger diff --git a/invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py deleted file mode 100644 index 6878218f679..00000000000 --- a/invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py +++ /dev/null @@ -1,1739 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# Adapted for use in InvokeAI by Lincoln Stein, July 2023 -# -""" Conversion script for the Stable Diffusion checkpoints.""" - -import re -from contextlib import nullcontext -from io import BytesIO -from pathlib import Path -from typing import Optional, Union - -import requests -import torch -from diffusers.models import AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer -from diffusers.schedulers import ( - DDIMScheduler, - DDPMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - UnCLIPScheduler, -) -from diffusers.utils import is_accelerate_available -from picklescan.scanner import scan_file_path -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextConfig, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger - -from .models import BaseModelType, ModelVariantType - -try: - from omegaconf import OmegaConf - from omegaconf.dictconfig import DictConfig -except ImportError: - raise ImportError( - "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." - ) - -if is_accelerate_available(): - from accelerate import init_empty_weights - from accelerate.utils import set_module_tensor_to_device - -logger = InvokeAILogger.get_logger(__name__) -CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert" - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. 
- """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "to_q.weight") - new_item = new_item.replace("q.bias", "to_q.bias") - - new_item = new_item.replace("k.weight", "to_k.weight") - new_item = new_item.replace("k.bias", "to_k.bias") - - new_item = new_item.replace("v.weight", "to_v.weight") - new_item = new_item.replace("v.bias", "to_v.bias") - - new_item = new_item.replace("proj_out.weight", "to_out.0.weight") - new_item = new_item.replace("proj_out.bias", "to_out.0.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. 
- if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) - shape = old_checkpoint[path["old"]].shape - if is_attn_weight and len(shape) == 3: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - elif is_attn_weight and len(shape) == 4: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): - """ - Creates a config for the diffusers based on the config of the LDM model. 
- """ - if controlnet: - unet_params = original_config.model.params.control_stage_config.params - else: - if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: - unet_params = original_config.model.params.unet_config.params - else: - unet_params = original_config.model.params.network_config.params - - vae_params = original_config.model.params.first_stage_config.params.ddconfig - - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for _i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - if unet_params.transformer_depth is not None: - transformer_layers_per_block = ( - unet_params.transformer_depth - if isinstance(unet_params.transformer_depth, int) - else list(unet_params.transformer_depth) - ) - else: - transformer_layers_per_block = 1 - - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) - - head_dim = unet_params.num_heads if "num_heads" in unet_params else None - use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False - ) - if use_linear_projection: - # stable diffusion 2-base-512 and 2-768 - if head_dim is None: - head_dim_mult = unet_params.model_channels // unet_params.num_head_channels - head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] - - class_embed_type = None - addition_embed_type = None - addition_time_embed_dim = None - projection_class_embeddings_input_dim = None - context_dim = None - - if unet_params.context_dim is not None: - context_dim = ( - unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] - ) - - if "num_classes" in unet_params: - if unet_params.num_classes == "sequential": - if context_dim in [2048, 1280]: - # SDXL - addition_embed_type = "text_time" - addition_time_embed_dim = 256 - else: - class_embed_type = "projection" - assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params.adm_in_channels - else: - raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") - - config = { - "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, - "down_block_types": tuple(down_block_types), - "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, - "cross_attention_dim": context_dim, - "attention_head_dim": head_dim, - "use_linear_projection": use_linear_projection, - "class_embed_type": class_embed_type, - "addition_embed_type": addition_embed_type, - "addition_time_embed_dim": addition_time_embed_dim, - "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, - "transformer_layers_per_block": transformer_layers_per_block, - } - - if controlnet: - config["conditioning_channels"] = unet_params.hint_channels - else: - config["out_channels"] = unet_params.out_channels - config["up_block_types"] = tuple(up_block_types) - - return config - - -def create_vae_diffusers_config(original_config, image_size: int): - """ - 
Creates a config for the diffusers based on the config of the LDM model. - """ - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim - - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = { - "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, - "down_block_types": tuple(down_block_types), - "up_block_types": tuple(up_block_types), - "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, - } - return config - - -def create_diffusers_schedular(original_config): - schedular = DDIMScheduler( - num_train_timesteps=original_config.model.params.timesteps, - beta_start=original_config.model.params.linear_start, - beta_end=original_config.model.params.linear_end, - beta_schedule="scaled_linear", - ) - return schedular - - -def create_ldm_bert_config(original_config): - bert_params = original_config.model.parms.cond_stage_config.params - config = LDMBertConfig( - d_model=bert_params.n_embed, - encoder_layers=bert_params.n_layer, - encoder_ffn_dim=bert_params.n_embed * 4, - ) - return config - - -def convert_ldm_unet_checkpoint( - checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False -): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - if skip_extract_state_dict: - unet_state_dict = checkpoint - else: - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - if controlnet: - unet_key = "control_model." - else: - unet_key = "model.diffusion_model." - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") - logger.warning( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - logger.warning( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." - ) - - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - if config["class_embed_type"] is None: - # No parameters to port - ... 
- elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - else: - raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") - - if config["addition_embed_type"] == "text_time": - new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - if not controlnet: - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, 
new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - output_block_list = {k: sorted(v) for k, v in output_block_list.items()} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - - # Clear attentions as they have been attributed above. 
- if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - if controlnet: - # conditioning embedding - - orig_index = 0 - - new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - orig_index += 2 - - diffusers_index = 0 - - while diffusers_index < 6: - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - diffusers_index += 1 - orig_index += 2 - - new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - # down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - keys = list(checkpoint.keys()) - vae_key = "first_stage_model." 
if any(k.startswith("first_stage_model.") for k in keys) else "" - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": 
"mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def convert_ldm_bert_checkpoint(checkpoint, config): - def _copy_attn_layer(hf_attn_layer, pt_attn_layer): - hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight - hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight - hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight - - hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight - hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias - - def _copy_linear(hf_linear, pt_linear): - hf_linear.weight = pt_linear.weight - hf_linear.bias = pt_linear.bias - - def _copy_layer(hf_layer, pt_layer): - # copy layer norms - _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) - _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) - - # copy attn - _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) - - # copy MLP - pt_mlp = pt_layer[1][1] - _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) - _copy_linear(hf_layer.fc2, pt_mlp.net[2]) - - def _copy_layers(hf_layers, pt_layers): - for i, hf_layer in enumerate(hf_layers): - if i != 0: - i += i - pt_layer = pt_layers[i : i + 2] - _copy_layer(hf_layer, pt_layer) - - hf_model = LDMBertModel(config).eval() - - # copy embeds - hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight - - # copy layer norm - _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) - - # copy hidden layers - _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) - - _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) - - return hf_model - - -def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): - if text_encoder is None: - config = 
CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - text_model = CLIPTextModel(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] - - for key in keys: - for prefix in remove_prefixes: - if key.startswith(prefix): - text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] - - if is_accelerate_available(): - for param_name, param in text_model_dict.items(): - set_module_tensor_to_device(text_model, param_name, "cpu", value=param) - else: - text_model.load_state_dict(text_model_dict) - - return text_model - - -textenc_conversion_lst = [ - ("positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), - ("ln_final.weight", "text_model.final_layer_norm.weight"), - ("ln_final.bias", "text_model.final_layer_norm.bias"), - ("text_projection", "text_projection.weight"), -] -textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} - -textenc_transformer_conversion_lst = [ - # (stable-diffusion, HF Diffusers) - ("resblocks.", "text_model.encoder.layers."), - ("ln_1", "layer_norm1"), - ("ln_2", "layer_norm2"), - (".c_fc.", ".fc1."), - (".c_proj.", ".fc2."), - (".attn", ".self_attn"), - ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), -] -protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} -textenc_pattern = re.compile("|".join(protected.keys())) - - -def convert_paint_by_example_checkpoint(checkpoint): - config = CLIPVisionConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - model = PaintByExampleImageEncoder(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - # load clip vision - model.model.load_state_dict(text_model_dict) - - # load mapper - keys_mapper = { - k[len("cond_stage_model.mapper.res") :]: v - for k, v in checkpoint.items() - if k.startswith("cond_stage_model.mapper") - } - - MAPPING = { - "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], - "attn.c_proj": ["attn1.to_out.0"], - "ln_1": ["norm1"], - "ln_2": ["norm3"], - "mlp.c_fc": ["ff.net.0.proj"], - "mlp.c_proj": ["ff.net.2"], - } - - mapped_weights = {} - for key, value in keys_mapper.items(): - prefix = key[: len("blocks.i")] - suffix = key.split(prefix)[-1].split(".")[-1] - name = key.split(prefix)[-1].split(suffix)[0][1:-1] - mapped_names = MAPPING[name] - - num_splits = len(mapped_names) - for i, mapped_name in enumerate(mapped_names): - new_name = ".".join([prefix, mapped_name, suffix]) - shape = value.shape[0] // num_splits - mapped_weights[new_name] = value[i * shape : (i + 1) * shape] - - model.mapper.load_state_dict(mapped_weights) - - # load final layer norm - model.final_layer_norm.load_state_dict( - { - "bias": checkpoint["cond_stage_model.final_ln.bias"], - "weight": checkpoint["cond_stage_model.final_ln.weight"], - } - ) - - # load final proj - model.proj_out.load_state_dict( - { - "bias": checkpoint["proj_out.bias"], - "weight": checkpoint["proj_out.weight"], - } - ) - - # load 
uncond vector - model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) - return model - - -def convert_open_clip_checkpoint( - checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs -): - # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") - # text_model = CLIPTextModelWithProjection.from_pretrained( - # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 - # ) - config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) - - keys = list(checkpoint.keys()) - - keys_to_ignore = [] - if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: - # make sure to remove all keys > 22 - keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] - keys_to_ignore += ["cond_stage_model.model.text_projection"] - - text_model_dict = {} - - if prefix + "text_projection" in checkpoint: - d_model = int(checkpoint[prefix + "text_projection"].shape[0]) - else: - d_model = 1024 - - text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") - - for key in keys: - if key in keys_to_ignore: - continue - if key[len(prefix) :] in textenc_conversion_map: - if key.endswith("text_projection"): - value = checkpoint[key].T.contiguous() - else: - value = checkpoint[key] - - text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value - - if key.startswith(prefix + "transformer."): - new_key = key[len(prefix + "transformer.") :] - if new_key.endswith(".in_proj_weight"): - new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] - elif new_key.endswith(".in_proj_bias"): - new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] - else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - - text_model_dict[new_key] = checkpoint[key] - - if is_accelerate_available(): - for param_name, param in text_model_dict.items(): - set_module_tensor_to_device(text_model, param_name, "cpu", value=param) - else: - text_model.load_state_dict(text_model_dict) - - return text_model - - -def stable_unclip_image_encoder(original_config): - """ - Returns the image processor and clip image encoder for the img2img unclip pipeline. - - We currently know of two types of stable unclip models which separately use the clip and the openclip image - encoders. 
- """ - - image_embedder_config = original_config.model.params.embedder_config - - sd_clip_image_embedder_class = image_embedder_config.target - sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] - - if sd_clip_image_embedder_class == "ClipImageEmbedder": - clip_model_name = image_embedder_config.params.model - - if clip_model_name == "ViT-L/14": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - else: - raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") - - elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": - feature_extractor = CLIPImageProcessor() - # InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K" - ) - else: - raise NotImplementedError( - f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" - ) - - return feature_extractor, image_encoder - - -def stable_unclip_image_noising_components( - original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None -): - """ - Returns the noising components for the img2img and txt2img unclip pipelines. - - Converts the stability noise augmentor into - 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats - 2. a `DDPMScheduler` for holding the noise schedule - - If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. - """ - noise_aug_config = original_config.model.params.noise_aug_config - noise_aug_class = noise_aug_config.target - noise_aug_class = noise_aug_class.split(".")[-1] - - if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": - noise_aug_config = noise_aug_config.params - embedding_dim = noise_aug_config.timestep_dim - max_noise_level = noise_aug_config.noise_schedule_config.timesteps - beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule - - image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) - image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) - - if "clip_stats_path" in noise_aug_config: - if clip_stats_path is None: - raise ValueError("This stable unclip config requires a `clip_stats_path`") - - clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) - clip_mean = clip_mean[None, :] - clip_std = clip_std[None, :] - - clip_stats_state_dict = { - "mean": clip_mean, - "std": clip_std, - } - - image_normalizer.load_state_dict(clip_stats_state_dict) - else: - raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") - - return image_normalizer, image_noising_scheduler - - -def convert_controlnet_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=None, - cross_attention_dim=None, - precision: Optional[torch.dtype] = None, -): - ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) - ctrlnet_config["upcast_attention"] = upcast_attention - - ctrlnet_config.pop("sample_size") - original_config = ctrlnet_config.copy() - - ctrlnet_config.pop("addition_embed_type") - ctrlnet_config.pop("addition_time_embed_dim") - ctrlnet_config.pop("transformer_layers_per_block") - - if 
use_linear_projection is not None: - ctrlnet_config["use_linear_projection"] = use_linear_projection - - if cross_attention_dim is not None: - ctrlnet_config["cross_attention_dim"] = cross_attention_dim - - controlnet = ControlNetModel(**ctrlnet_config) - - # Some controlnet ckpt files are distributed independently from the rest of the - # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ - if "time_embed.0.weight" in checkpoint: - skip_extract_state_dict = True - else: - skip_extract_state_dict = False - - converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, - original_config, - path=checkpoint_path, - extract_ema=extract_ema, - controlnet=True, - skip_extract_state_dict=skip_extract_state_dict, - ) - - controlnet.load_state_dict(converted_ctrl_checkpoint) - - return controlnet.to(precision) - - -def download_from_original_stable_diffusion_ckpt( - checkpoint_path: str, - model_version: BaseModelType, - model_variant: ModelVariantType, - original_config_file: str = None, - image_size: Optional[int] = None, - prediction_type: str = None, - model_type: str = None, - extract_ema: bool = False, - precision: Optional[torch.dtype] = None, - scheduler_type: str = "pndm", - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - stable_unclip: Optional[str] = None, - stable_unclip_prior: Optional[str] = None, - clip_stats_path: Optional[str] = None, - controlnet: Optional[bool] = None, - load_safety_checker: bool = True, - pipeline_class: DiffusionPipeline = None, - local_files_only=False, - vae_path=None, - text_encoder=None, - tokenizer=None, - scan_needed: bool = True, -) -> DiffusionPipeline: - """ - Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` - config file. - - Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the - global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is - recommended that you override the default values and/or supply an `original_config_file` wherever possible. - - Args: - checkpoint_path (`str`): Path to `.ckpt` file. - original_config_file (`str`): - Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically - inferred by looking for a key that only exists in SD2.0 models. - image_size (`int`, *optional*, defaults to 512): - The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 - Base. Use 768 for Stable Diffusion v2. - prediction_type (`str`, *optional*): - The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable - Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. - num_in_channels (`int`, *optional*, defaults to None): - The number of input channels. If `None`, it will be automatically inferred. - scheduler_type (`str`, *optional*, defaults to 'pndm'): - Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", - "ddim"]`. - model_type (`str`, *optional*, defaults to `None`): - The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", - "FrozenCLIPEmbedder", "PaintByExample"]`. - is_img2img (`bool`, *optional*, defaults to `False`): - Whether the model should be loaded as an img2img pipeline. 
- extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for - checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to - `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for - inference. Non-EMA weights are usually better to continue fine-tuning. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. This is necessary when running stable - diffusion 2.1. - device (`str`, *optional*, defaults to `None`): - The device to use. Pass `None` to determine automatically. - from_safetensors (`str`, *optional*, defaults to `False`): - If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. - load_safety_checker (`bool`, *optional*, defaults to `True`): - Whether to load the safety checker or not. Defaults to `True`. - pipeline_class (`str`, *optional*, defaults to `None`): - The pipeline class to use. Pass `None` to determine automatically. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): - An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) - to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) - variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. - tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): - An instance of - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) - to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if - needed. - precision (`torch.dtype`, *optional*, defauts to `None`): - If not provided the precision will be set to the precision of the original file. - return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. - """ - - # import pipelines here to avoid circular import error when using from_single_file method - from diffusers import ( - LDMTextToImagePipeline, - PaintByExamplePipeline, - StableDiffusionControlNetPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLPipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - ) - - if pipeline_class is None: - pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline - - if prediction_type == "v-prediction": - prediction_type = "v_prediction" - - if from_safetensors: - from safetensors.torch import load_file as safe_load - - checkpoint = safe_load(checkpoint_path, device="cpu") - else: - if scan_needed: - # scan model - scan_result = scan_file_path(checkpoint_path) - if scan_result.infected_files != 0: - raise Exception("The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # Sometimes models don't have the global_step item - if "global_step" in checkpoint: - global_step = checkpoint["global_step"] - else: - logger.debug("global_step key not found in model") - global_step = None - - # NOTE: this while loop isn't great but this controlnet checkpoint has one additional - # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}") - - precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias" - logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}") - precision = precision or checkpoint[precision_probing_key].dtype - - if original_config_file is None: - key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" - key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" - - # model_type = "v1" - config_url = ( - "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - ) - - if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: - # model_type = "v2" - config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" - - if global_step == 110000: - # v2.1 needs to upcast attention - upcast_attention = True - elif key_name_sd_xl_base in checkpoint: - # only base xl has two text embedders - config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" - elif key_name_sd_xl_refiner in checkpoint: - # only refiner xl has embedder and one text embedders - config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" - - original_config_file = BytesIO(requests.get(config_url).content) - - original_config = OmegaConf.load(original_config_file) - if original_config["model"]["params"].get("use_ema") is not None: - extract_ema = original_config["model"]["params"]["use_ema"] - - if ( - model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1] - and original_config["model"]["params"].get("parameterization") == "v" - ): - prediction_type = "v_prediction" - upcast_attention = True - image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512 - else: - prediction_type = "epsilon" - upcast_attention = False - image_size = 512 - - # Convert the text model. 
- if ( - model_type is None - and "cond_stage_config" in original_config.model.params - and original_config.model.params.cond_stage_config is not None - ): - model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] - logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") - elif model_type is None and original_config.model.params.network_config is not None: - if original_config.model.params.network_config.params.context_dim == 2048: - model_type = "SDXL" - else: - model_type = "SDXL-Refiner" - if image_size is None: - image_size = 1024 - - if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: - num_in_channels = 9 - elif num_in_channels is None: - num_in_channels = 4 - - if "unet_config" in original_config.model.params: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" - ): - if prediction_type is None: - # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` - # as it relies on a brittle global step parameter here - prediction_type = "epsilon" if global_step == 875000 else "v_prediction" - if image_size is None: - # NOTE: For stable diffusion 2 base one has to pass `image_size==512` - # as it relies on a brittle global step parameter here - image_size = 512 if global_step == 875000 else 768 - else: - if prediction_type is None: - prediction_type = "epsilon" - if image_size is None: - image_size = 512 - - if controlnet is None and "control_stage_config" in original_config.model.params: - controlnet = convert_controlnet_checkpoint( - checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema - ) - - num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 - - if model_type in ["SDXL", "SDXL-Refiner"]: - scheduler_dict = { - "beta_schedule": "scaled_linear", - "beta_start": 0.00085, - "beta_end": 0.012, - "interpolation_type": "linear", - "num_train_timesteps": num_train_timesteps, - "prediction_type": "epsilon", - "sample_max_value": 1.0, - "set_alpha_to_one": False, - "skip_prk_steps": True, - "steps_offset": 1, - "timestep_spacing": "leading", - } - scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) - scheduler_type = "euler" - else: - beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 - beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 - scheduler = DDIMScheduler( - beta_end=beta_end, - beta_schedule="scaled_linear", - beta_start=beta_start, - num_train_timesteps=num_train_timesteps, - steps_offset=1, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - ) - # make sure scheduler works correctly with DDIM - scheduler.register_to_config(clip_sample=False) - - if scheduler_type == "pndm": - config = dict(scheduler.config) - config["skip_prk_steps"] = True - scheduler = PNDMScheduler.from_config(config) - elif scheduler_type == "lms": - scheduler = LMSDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "heun": - scheduler = HeunDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler": - scheduler = EulerDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler-ancestral": - scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "dpm": - 
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) - elif scheduler_type == "ddim": - scheduler = scheduler - else: - raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) - unet_config["upcast_attention"] = upcast_attention - converted_unet_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema - ) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - unet = UNet2DConditionModel(**unet_config) - - if is_accelerate_available(): - for param_name, param in converted_unet_checkpoint.items(): - set_module_tensor_to_device(unet, param_name, "cpu", value=param) - else: - unet.load_state_dict(converted_unet_checkpoint) - - # Convert the VAE model. - if vae_path is None: - vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - - if ( - "model" in original_config - and "params" in original_config.model - and "scale_factor" in original_config.model.params - ): - vae_scaling_factor = original_config.model.params.scale_factor - else: - vae_scaling_factor = 0.18215 # default SD scaling factor - - vae_config["scaling_factor"] = vae_scaling_factor - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - vae = AutoencoderKL(**vae_config) - - if is_accelerate_available(): - for param_name, param in converted_vae_checkpoint.items(): - set_module_tensor_to_device(vae, param_name, "cpu", value=param) - else: - vae.load_state_dict(converted_vae_checkpoint) - else: - vae = AutoencoderKL.from_pretrained(vae_path) - - if model_type == "FrozenOpenCLIPEmbedder": - config_name = "stabilityai/stable-diffusion-2" - config_kwargs = {"subfolder": "text_encoder"} - - text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) - tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer") - - if stable_unclip is None: - if controlnet: - pipe = pipeline_class( - vae=vae.to(precision), - text_encoder=text_model.to(precision), - tokenizer=tokenizer, - unet=unet.to(precision), - scheduler=scheduler, - controlnet=controlnet, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - else: - pipe = pipeline_class( - vae=vae.to(precision), - text_encoder=text_model.to(precision), - tokenizer=tokenizer, - unet=unet.to(precision), - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - else: - image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( - original_config, clip_stats_path=clip_stats_path, device=device - ) - - if stable_unclip == "img2img": - feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) - - pipe = StableUnCLIPImg2ImgPipeline( - # image encoding components - feature_extractor=feature_extractor, - image_encoder=image_encoder, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model.to(precision), - unet=unet.to(precision), - scheduler=scheduler, - # vae - vae=vae, - ) - elif stable_unclip == "txt2img": - if stable_unclip_prior is None or stable_unclip_prior == 
"karlo": - karlo_model = "kakaobrain/karlo-v1-alpha" - prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") - - prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - prior_text_model = CLIPTextModelWithProjection.from_pretrained( - CONVERT_MODEL_ROOT / "clip-vit-large-patch14" - ) - - prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") - prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) - else: - raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") - - pipe = StableUnCLIPPipeline( - # prior components - prior_tokenizer=prior_tokenizer, - prior_text_encoder=prior_text_model, - prior=prior, - prior_scheduler=prior_scheduler, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model, - unet=unet, - scheduler=scheduler, - # vae - vae=vae, - ) - else: - raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") - elif model_type == "PaintByExample": - vision_model = convert_paint_by_example_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker") - pipe = PaintByExamplePipeline( - vae=vae, - image_encoder=vision_model, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=feature_extractor, - ) - elif model_type == "FrozenCLIPEmbedder": - text_model = convert_ldm_clip_checkpoint( - checkpoint, local_files_only=local_files_only, text_encoder=text_encoder - ) - tokenizer = ( - CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - if tokenizer is None - else tokenizer - ) - - if load_safety_checker: - safety_checker = StableDiffusionSafetyChecker.from_pretrained( - CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" - ) - feature_extractor = AutoFeatureExtractor.from_pretrained( - CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" - ) - else: - safety_checker = None - feature_extractor = None - - if controlnet: - pipe = pipeline_class( - vae=vae.to(precision), - text_encoder=text_model.to(precision), - tokenizer=tokenizer, - unet=unet.to(precision), - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - else: - pipe = pipeline_class( - vae=vae.to(precision), - text_encoder=text_model.to(precision), - tokenizer=tokenizer, - unet=unet.to(precision), - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - elif model_type in ["SDXL", "SDXL-Refiner"]: - if model_type == "SDXL": - tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) - - tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" - tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") - - config_name = tokenizer_name - config_kwargs = {"projection_dim": 1280} - text_encoder_2 = convert_open_clip_checkpoint( - checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs - ) - - pipe = StableDiffusionXLPipeline( - vae=vae.to(precision), - text_encoder=text_encoder.to(precision), - 
tokenizer=tokenizer, - text_encoder_2=text_encoder_2.to(precision), - tokenizer_2=tokenizer_2, - unet=unet.to(precision), - scheduler=scheduler, - force_zeros_for_empty_prompt=True, - ) - else: - tokenizer = None - text_encoder = None - tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" - tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") - - config_name = tokenizer_name - config_kwargs = {"projection_dim": 1280} - text_encoder_2 = convert_open_clip_checkpoint( - checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs - ) - - pipe = StableDiffusionXLImg2ImgPipeline( - vae=vae.to(precision), - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - unet=unet.to(precision), - scheduler=scheduler, - requires_aesthetics_score=True, - force_zeros_for_empty_prompt=False, - ) - else: - text_config = create_ldm_bert_config(original_config) - text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) - tokenizer = BertTokenizerFast.from_pretrained(CONVERT_MODEL_ROOT / "bert-base-uncased") - pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) - - return pipe - - -def download_controlnet_from_original_ckpt( - checkpoint_path: str, - original_config_file: str, - image_size: int = 512, - extract_ema: bool = False, - precision: Optional[torch.dtype] = None, - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - use_linear_projection: Optional[bool] = None, - cross_attention_dim: Optional[bool] = None, - scan_needed: bool = False, -) -> DiffusionPipeline: - if from_safetensors: - from safetensors import safe_open - - checkpoint = {} - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - checkpoint[key] = f.get_tensor(key) - else: - if scan_needed: - # scan model - scan_result = scan_file_path(checkpoint_path) - if scan_result.infected_files != 0: - raise Exception("The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # NOTE: this while loop isn't great but this controlnet checkpoint has one additional - # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - # use original precision - precision_probing_key = "input_blocks.0.0.bias" - ckpt_precision = checkpoint[precision_probing_key].dtype - logger.debug(f"original controlnet precision = {ckpt_precision}") - precision = precision or ckpt_precision - - original_config = OmegaConf.load(original_config_file) - - if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if "control_stage_config" not in original_config.model.params: - raise ValueError("`control_stage_config` not present in original config") - - controlnet = convert_controlnet_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=use_linear_projection, - cross_attention_dim=cross_attention_dim, - ) - - return controlnet.to(precision) - - -def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: - vae_config = create_vae_diffusers_config(vae_config, image_size=image_size) - - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae - - -def convert_ckpt_to_diffusers( - checkpoint_path: Union[str, Path], - dump_path: Union[str, Path], - use_safetensors: bool = True, - **kwargs, -): - """ - Takes all the arguments of download_from_original_stable_diffusion_ckpt(), - and in addition a path-like object indicating the location of the desired diffusers - model to be written. - """ - pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) - - pipe.save_pretrained( - dump_path, - safe_serialization=use_safetensors, - ) - - -def convert_controlnet_to_diffusers( - checkpoint_path: Union[str, Path], - dump_path: Union[str, Path], - **kwargs, -): - """ - Takes all the arguments of download_controlnet_from_original_ckpt(), - and in addition a path-like object indicating the location of the desired diffusers - model to be written. - """ - pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) - - pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_management_OLD/detect_baked_in_vae.py b/invokeai/backend/model_management_OLD/detect_baked_in_vae.py deleted file mode 100644 index 9118438548d..00000000000 --- a/invokeai/backend/model_management_OLD/detect_baked_in_vae.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team -""" -This module exports the function has_baked_in_sdxl_vae(). -It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE, -which doesn't work properly in fp16 mode. 
-""" - -import hashlib -from pathlib import Path - -from safetensors.torch import load_file - -SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51" - - -def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool: - """Return true if the checkpoint contains a custom (non SDXL-1.0) VAE.""" - hash = _vae_hash(checkpoint_path) - return hash != SDXL_1_0_VAE_HASH - - -def _vae_hash(checkpoint_path: Path) -> str: - checkpoint = load_file(checkpoint_path, device="cpu") - vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")] - hash = hashlib.new("sha256") - for key in vae_keys: - value = checkpoint[key] - hash.update(bytes(key, "UTF-8")) - hash.update(bytes(str(value), "UTF-8")) - - return hash.hexdigest() diff --git a/invokeai/backend/model_management_OLD/lora.py b/invokeai/backend/model_management_OLD/lora.py deleted file mode 100644 index aed5eb60d57..00000000000 --- a/invokeai/backend/model_management_OLD/lora.py +++ /dev/null @@ -1,582 +0,0 @@ -from __future__ import annotations - -import pickle -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from compel.embeddings_provider import BaseTextualInversionManager -from diffusers.models import UNet2DConditionModel -from safetensors.torch import load_file -from transformers import CLIPTextModel, CLIPTokenizer - -from invokeai.app.shared.models import FreeUConfig -from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init - -from .models.lora import LoRAModel - -""" -loras = [ - (lora_model1, 0.7), - (lora_model2, 0.4), -] -with LoRAHelper.apply_lora_unet(unet, loras): - # unet with applied loras -# unmodified unet - -""" - - -# TODO: rename smth like ModelPatcher and add TI method? -class ModelPatcher: - @staticmethod - def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: - assert "." not in lora_key - - if not lora_key.startswith(prefix): - raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") - - module = model - module_key = "" - key_parts = lora_key[len(prefix) :].split("_") - - submodule_name = key_parts.pop(0) - - while len(key_parts) > 0: - try: - module = module.get_submodule(submodule_name) - module_key += "." + submodule_name - submodule_name = key_parts.pop(0) - except Exception: - submodule_name += "_" + key_parts.pop(0) - - module = module.get_submodule(submodule_name) - module_key = (module_key + "." 
+ submodule_name).lstrip(".") - - return (module_key, module) - - @classmethod - @contextmanager - def apply_lora_unet( - cls, - unet: UNet2DConditionModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(unet, loras, "lora_unet_"): - yield - - @classmethod - @contextmanager - def apply_lora_text_encoder( - cls, - text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(text_encoder, loras, "lora_te_"): - yield - - @classmethod - @contextmanager - def apply_sdxl_lora_text_encoder( - cls, - text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(text_encoder, loras, "lora_te1_"): - yield - - @classmethod - @contextmanager - def apply_sdxl_lora_text_encoder2( - cls, - text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(text_encoder, loras, "lora_te2_"): - yield - - @classmethod - @contextmanager - def apply_lora( - cls, - model: torch.nn.Module, - loras: List[Tuple[LoRAModel, float]], # THIS IS INCORRECT. IT IS ACTUALLY A LoRAModelRaw - prefix: str, - ): - original_weights = {} - try: - with torch.no_grad(): - for lora, lora_weight in loras: - # assert lora.device.type == "cpu" - for layer_key, layer in lora.layers.items(): - if not layer_key.startswith(prefix): - continue - - # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This - # should be improved in the following ways: - # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a - # LoRA model is applied. - # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the - # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA - # weights to have valid keys. - module_key, module = cls._resolve_lora_key(model, layer_key, prefix) - - # All of the LoRA weight calculations will be done on the same device as the module weight. - # (Performance will be best if this is a CUDA device.) - device = module.weight.device - dtype = module.weight.dtype - - if module_key not in original_weights: - original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) - - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - - # We intentionally move to the target device first, then cast. Experimentally, this was found to - # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the - # same thing in a single call to '.to(...)'. - layer.to(device=device) - layer.to(dtype=torch.float32) - # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA - # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. 
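# A minimal sketch of the add-delta-then-restore pattern that apply_lora() implements here:
# save a CPU copy of each patched weight, add the scaled LoRA delta in place, and copy the
# original back on exit. The module and delta names below are illustrative only; the real code
# derives the delta per layer via layer.get_weight() and the alpha/rank scale.
import torch

def patch_weight(module: torch.nn.Linear, delta: torch.Tensor, scale: float) -> torch.Tensor:
    """Add a scaled delta to module.weight and return a CPU copy of the original weight."""
    original = module.weight.detach().to(device="cpu", copy=True)
    with torch.no_grad():
        module.weight += (delta * scale).to(device=module.weight.device, dtype=module.weight.dtype)
    return original

def restore_weight(module: torch.nn.Linear, original: torch.Tensor) -> None:
    with torch.no_grad():
        module.weight.copy_(original)

# Usage mirrors the try/finally above: patch, run inference, then restore.
linear = torch.nn.Linear(8, 8)
saved = patch_weight(linear, torch.randn(8, 8), scale=0.7)
restore_weight(linear, saved)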
- layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) - layer.to(device="cpu") - - if module.weight.shape != layer_weight.shape: - # TODO: debug on lycoris - layer_weight = layer_weight.reshape(module.weight.shape) - - module.weight += layer_weight.to(dtype=dtype) - - yield # wait for context manager exit - - finally: - with torch.no_grad(): - for module_key, weight in original_weights.items(): - model.get_submodule(module_key).weight.copy_(weight) - - @classmethod - @contextmanager - def apply_ti( - cls, - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - ti_list: List[Tuple[str, Any]], - ) -> Tuple[CLIPTokenizer, TextualInversionManager]: - init_tokens_count = None - new_tokens_added = None - - # TODO: This is required since Transformers 4.32 see - # https://github.com/huggingface/transformers/pull/25088 - # More information by NVIDIA: - # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc - # This value might need to be changed in the future and take the GPUs model into account as there seem - # to be ideal values for different GPUS. This value is temporary! - # For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817 - pad_to_multiple_of = 8 - - try: - # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a - # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after - # exiting this `apply_ti(...)` context manager. - # - # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, - # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). - ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) - ti_manager = TextualInversionManager(ti_tokenizer) - init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings - - def _get_trigger(ti_name, index): - trigger = ti_name - if index > 0: - trigger += f"-!pad-{i}" - return f"<{trigger}>" - - def _get_ti_embedding(model_embeddings, ti): - print(f"DEBUG: model_embeddings={type(model_embeddings)}, ti={type(ti)}") - print(f"DEBUG: is it an nn.Module? {isinstance(model_embeddings, torch.nn.Module)}") - # for SDXL models, select the embedding that matches the text encoder's dimensions - if ti.embedding_2 is not None: - return ( - ti.embedding_2 - if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0] - else ti.embedding - ) - else: - print(f"DEBUG: ti.embedding={type(ti.embedding)}") - return ti.embedding - - # modify tokenizer - new_tokens_added = 0 - for ti_name, ti in ti_list: - ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) - - for i in range(ti_embedding.shape[0]): - new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) - - # Modify text_encoder. - # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of - # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some - # time. 
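# skip_torch_weight_init() is imported from the model-load optimizations module; its
# implementation is not shown in this diff. A rough sketch of the idea, for orientation only:
# temporarily replace reset_parameters() on common layer types with a no-op so freshly
# constructed modules skip an initialization whose values are about to be overwritten anyway.
from contextlib import contextmanager
import torch

@contextmanager
def _skip_weight_init_sketch():
    layer_types = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d)
    saved = {t: t.reset_parameters for t in layer_types}
    try:
        for t in layer_types:
            t.reset_parameters = lambda self: None  # no-op initialization
        yield
    finally:
        for t, original in saved.items():
            t.reset_parameters = original

# Usage: building a large embedding inside the context avoids the costly random init,
# e.g. with _skip_weight_init_sketch(): embeddings = torch.nn.Embedding(49416, 1280)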
- with skip_torch_weight_init(): - text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of) - model_embeddings = text_encoder.get_input_embeddings() - - for ti_name, ti in ti_list: - ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) - - ti_tokens = [] - for i in range(ti_embedding.shape[0]): - embedding = ti_embedding[i] - trigger = _get_trigger(ti_name, i) - - token_id = ti_tokenizer.convert_tokens_to_ids(trigger) - if token_id == ti_tokenizer.unk_token_id: - raise RuntimeError(f"Unable to find token id for token '{trigger}'") - - if model_embeddings.weight.data[token_id].shape != embedding.shape: - raise ValueError( - f"Cannot load embedding for {trigger}. It was trained on a model with token dimension" - f" {embedding.shape[0]}, but the current model has token dimension" - f" {model_embeddings.weight.data[token_id].shape[0]}." - ) - - model_embeddings.weight.data[token_id] = embedding.to( - device=text_encoder.device, dtype=text_encoder.dtype - ) - ti_tokens.append(token_id) - - if len(ti_tokens) > 1: - ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] - - yield ti_tokenizer, ti_manager - - finally: - if init_tokens_count and new_tokens_added: - text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of) - - @classmethod - @contextmanager - def apply_clip_skip( - cls, - text_encoder: CLIPTextModel, - clip_skip: int, - ): - skipped_layers = [] - try: - for _i in range(clip_skip): - skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1)) - - yield - - finally: - while len(skipped_layers) > 0: - text_encoder.text_model.encoder.layers.append(skipped_layers.pop()) - - @classmethod - @contextmanager - def apply_freeu( - cls, - unet: UNet2DConditionModel, - freeu_config: Optional[FreeUConfig] = None, - ): - did_apply_freeu = False - try: - if freeu_config is not None: - unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2) - did_apply_freeu = True - - yield - - finally: - if did_apply_freeu: - unet.disable_freeu() - - -class TextualInversionModel: - embedding: torch.Tensor # [n, 768]|[n, 1280] - embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - if not isinstance(file_path, Path): - file_path = Path(file_path) - - result = cls() # TODO: - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - # both v1 and v2 format embeddings - # difference mostly in metadata - if "string_to_param" in state_dict: - if len(state_dict["string_to_param"]) > 1: - print( - f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. 
The first', - " token will be used.", - ) - - result.embedding = next(iter(state_dict["string_to_param"].values())) - - # v3 (easynegative) - elif "emb_params" in state_dict: - result.embedding = state_dict["emb_params"] - - # v5(sdxl safetensors file) - elif "clip_g" in state_dict and "clip_l" in state_dict: - result.embedding = state_dict["clip_g"] - result.embedding_2 = state_dict["clip_l"] - - # v4(diffusers bin files) - else: - result.embedding = next(iter(state_dict.values())) - - if len(result.embedding.shape) == 1: - result.embedding = result.embedding.unsqueeze(0) - - if not isinstance(result.embedding, torch.Tensor): - raise ValueError(f"Invalid embeddings file: {file_path.name}") - - return result - - -class TextualInversionManager(BaseTextualInversionManager): - pad_tokens: Dict[int, List[int]] - tokenizer: CLIPTokenizer - - def __init__(self, tokenizer: CLIPTokenizer): - self.pad_tokens = {} - self.tokenizer = tokenizer - - def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: - if len(self.pad_tokens) == 0: - return token_ids - - if token_ids[0] == self.tokenizer.bos_token_id: - raise ValueError("token_ids must not start with bos_token_id") - if token_ids[-1] == self.tokenizer.eos_token_id: - raise ValueError("token_ids must not end with eos_token_id") - - new_token_ids = [] - for token_id in token_ids: - new_token_ids.append(token_id) - if token_id in self.pad_tokens: - new_token_ids.extend(self.pad_tokens[token_id]) - - # Do not exceed the max model input size - # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), - # which first removes and then adds back the start and end tokens. - max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 - if len(new_token_ids) > max_length: - new_token_ids = new_token_ids[0:max_length] - - return new_token_ids - - -class ONNXModelPatcher: - from diffusers import OnnxRuntimeModel - - from .models.base import IAIOnnxRuntimeModel - - @classmethod - @contextmanager - def apply_lora_unet( - cls, - unet: OnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(unet, loras, "lora_unet_"): - yield - - @classmethod - @contextmanager - def apply_lora_text_encoder( - cls, - text_encoder: OnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(text_encoder, loras, "lora_te_"): - yield - - # based on - # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323 - @classmethod - @contextmanager - def apply_lora( - cls, - model: IAIOnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], - prefix: str, - ): - from .models.base import IAIOnnxRuntimeModel - - if not isinstance(model, IAIOnnxRuntimeModel): - raise Exception("Only IAIOnnxRuntimeModel models supported") - - orig_weights = {} - - try: - blended_loras = {} - - for lora, lora_weight in loras: - for layer_key, layer in lora.layers.items(): - if not layer_key.startswith(prefix): - continue - - layer.to(dtype=torch.float32) - layer_key = layer_key.replace(prefix, "") - # TODO: rewrite to pass original tensor weight(required by ia3) - layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight - if layer_key is blended_loras: - blended_loras[layer_key] += layer_weight - else: - blended_loras[layer_key] = layer_weight - - node_names = {} - for node in model.nodes.values(): - node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name - - for 
layer_key, lora_weight in blended_loras.items(): - conv_key = layer_key + "_Conv" - gemm_key = layer_key + "_Gemm" - matmul_key = layer_key + "_MatMul" - - if conv_key in node_names or gemm_key in node_names: - if conv_key in node_names: - conv_node = model.nodes[node_names[conv_key]] - else: - conv_node = model.nodes[node_names[gemm_key]] - - weight_name = [n for n in conv_node.input if ".weight" in n][0] - orig_weight = model.tensors[weight_name] - - if orig_weight.shape[-2:] == (1, 1): - if lora_weight.shape[-2:] == (1, 1): - new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2)) - else: - new_weight = orig_weight.squeeze((3, 2)) + lora_weight - - new_weight = np.expand_dims(new_weight, (2, 3)) - else: - if orig_weight.shape != lora_weight.shape: - new_weight = orig_weight + lora_weight.reshape(orig_weight.shape) - else: - new_weight = orig_weight + lora_weight - - orig_weights[weight_name] = orig_weight - model.tensors[weight_name] = new_weight.astype(orig_weight.dtype) - - elif matmul_key in node_names: - weight_node = model.nodes[node_names[matmul_key]] - matmul_name = [n for n in weight_node.input if "MatMul" in n][0] - - orig_weight = model.tensors[matmul_name] - new_weight = orig_weight + lora_weight.transpose() - - orig_weights[matmul_name] = orig_weight - model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype) - - else: - # warn? err? - pass - - yield - - finally: - # restore original weights - for name, orig_weight in orig_weights.items(): - model.tensors[name] = orig_weight - - @classmethod - @contextmanager - def apply_ti( - cls, - tokenizer: CLIPTokenizer, - text_encoder: IAIOnnxRuntimeModel, - ti_list: List[Tuple[str, Any]], - ) -> Tuple[CLIPTokenizer, TextualInversionManager]: - from .models.base import IAIOnnxRuntimeModel - - if not isinstance(text_encoder, IAIOnnxRuntimeModel): - raise Exception("Only IAIOnnxRuntimeModel models supported") - - orig_embeddings = None - - try: - # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a - # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after - # exiting this `apply_ti(...)` context manager. - # - # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, - # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). 
- ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) - ti_manager = TextualInversionManager(ti_tokenizer) - - def _get_trigger(ti_name, index): - trigger = ti_name - if index > 0: - trigger += f"-!pad-{i}" - return f"<{trigger}>" - - # modify text_encoder - orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] - - # modify tokenizer - new_tokens_added = 0 - for ti_name, ti in ti_list: - if ti.embedding_2 is not None: - ti_embedding = ( - ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding - ) - else: - ti_embedding = ti.embedding - - for i in range(ti_embedding.shape[0]): - new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) - - embeddings = np.concatenate( - (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))), - axis=0, - ) - - for ti_name, _ in ti_list: - ti_tokens = [] - for i in range(ti_embedding.shape[0]): - embedding = ti_embedding[i].detach().numpy() - trigger = _get_trigger(ti_name, i) - - token_id = ti_tokenizer.convert_tokens_to_ids(trigger) - if token_id == ti_tokenizer.unk_token_id: - raise RuntimeError(f"Unable to find token id for token '{trigger}'") - - if embeddings[token_id].shape != embedding.shape: - raise ValueError( - f"Cannot load embedding for {trigger}. It was trained on a model with token dimension" - f" {embedding.shape[0]}, but the current model has token dimension" - f" {embeddings[token_id].shape[0]}." - ) - - embeddings[token_id] = embedding - ti_tokens.append(token_id) - - if len(ti_tokens) > 1: - ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] - - text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype( - orig_embeddings.dtype - ) - - yield ti_tokenizer, ti_manager - - finally: - # restore - if orig_embeddings is not None: - text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings diff --git a/invokeai/backend/model_management_OLD/memory_snapshot.py b/invokeai/backend/model_management_OLD/memory_snapshot.py deleted file mode 100644 index fe54af191ce..00000000000 --- a/invokeai/backend/model_management_OLD/memory_snapshot.py +++ /dev/null @@ -1,99 +0,0 @@ -import gc -from typing import Optional - -import psutil -import torch - -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 - -GB = 2**30 # 1 GB - - -class MemorySnapshot: - """A snapshot of RAM and VRAM usage. All values are in bytes.""" - - def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]): - """Initialize a MemorySnapshot. - - Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`. - - Args: - process_ram (int): CPU RAM used by the current process. - vram (Optional[int]): VRAM used by torch. - malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil. - """ - self.process_ram = process_ram - self.vram = vram - self.malloc_info = malloc_info - - @classmethod - def capture(cls, run_garbage_collector: bool = True): - """Capture and return a MemorySnapshot. - - Note: This function has significant overhead, particularly if `run_garbage_collector == True`. - - Args: - run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM - usage. Defaults to True. 
- - Returns: - MemorySnapshot - """ - if run_garbage_collector: - gc.collect() - - # According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is - # supported on all platforms. - process_ram = psutil.Process().memory_info().rss - - if torch.cuda.is_available(): - vram = torch.cuda.memory_allocated() - else: - # TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have - # time to test it properly. - vram = None - - try: - malloc_info = LibcUtil().mallinfo2() - except (OSError, AttributeError): - # OSError: This is expected in environments that do not have the 'libc.so.6' shared library. - # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33) - # TODO: Does `mallinfo` work? - malloc_info = None - - return cls(process_ram, vram, malloc_info) - - -def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str: - """Get a pretty string describing the difference between two `MemorySnapshot`s.""" - - def get_msg_line(prefix: str, val1: int, val2: int): - diff = val2 - val1 - return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n" - - msg = "" - - if snapshot_1 is None or snapshot_2 is None: - return msg - - msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram) - - if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: - msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd) - - msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks) - - msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks) - - libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd - libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd - msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2) - - libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd - libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd - msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2) - - if snapshot_1.vram is not None and snapshot_2.vram is not None: - msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) - - return msg diff --git a/invokeai/backend/model_management_OLD/model_cache.py b/invokeai/backend/model_management_OLD/model_cache.py deleted file mode 100644 index 2a7f4b5a95e..00000000000 --- a/invokeai/backend/model_management_OLD/model_cache.py +++ /dev/null @@ -1,553 +0,0 @@ -""" -Manage a RAM cache of diffusion/transformer models for fast switching. -They are moved between GPU VRAM and CPU RAM as necessary. If the cache -grows larger than a preset maximum, then the least recently used -model will be cleared and (re)loaded from disk when next needed. - -The cache returns context manager generators designed to load the -model into the GPU within the context, and unload outside the -context. 
Use like this: - - cache = ModelCache(max_cache_size=7.5) - with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, - cache.get_model('stabilityai/stable-diffusion-2') as SD2: - do_something_in_GPU(SD1,SD2) - - -""" - -import gc -import hashlib -import math -import os -import sys -import time -from contextlib import suppress -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict, Optional, Type, Union, types - -import torch - -import invokeai.backend.util.logging as logger -from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff -from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init - -from ..util.devices import choose_torch_device -from .models import BaseModelType, ModelBase, ModelType, SubModelType - -if choose_torch_device() == torch.device("mps"): - from torch import mps - -# Maximum size of the cache, in gigs -# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously -DEFAULT_MAX_CACHE_SIZE = 6.0 - -# amount of GPU memory to hold in reserve for use by generations (GB) -DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 - -# actual size of a gig -GIG = 1073741824 -# Size of a MB in bytes. -MB = 2**20 - - -@dataclass -class CacheStats(object): - hits: int = 0 # cache hits - misses: int = 0 # cache misses - high_watermark: int = 0 # amount of cache used - in_cache: int = 0 # number of models in cache - cleared: int = 0 # number of models cleared to make space - cache_size: int = 0 # total size of cache - # {submodel_key => size} - loaded_model_sizes: Dict[str, int] = field(default_factory=dict) - - -class ModelLocker(object): - "Forward declaration" - - pass - - -class ModelCache(object): - "Forward declaration" - - pass - - -class _CacheRecord: - size: int - model: Any - cache: ModelCache - _locks: int - - def __init__(self, cache, model: Any, size: int): - self.size = size - self.model = model - self.cache = cache - self._locks = 0 - - def lock(self): - self._locks += 1 - - def unlock(self): - self._locks -= 1 - assert self._locks >= 0 - - @property - def locked(self): - return self._locks > 0 - - @property - def loaded(self): - if self.model is not None and hasattr(self.model, "device"): - return self.model.device != self.cache.storage_device - else: - return False - - -class ModelCache(object): - def __init__( - self, - max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, - max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, - execution_device: torch.device = torch.device("cuda"), - storage_device: torch.device = torch.device("cpu"), - precision: torch.dtype = torch.float16, - sequential_offload: bool = False, - lazy_offloading: bool = True, - sha_chunksize: int = 16777216, - logger: types.ModuleType = logger, - log_memory_usage: bool = False, - ): - """ - :param max_cache_size: Maximum size of the RAM cache [6.0 GB] - :param execution_device: Torch device to load active model into [torch.device('cuda')] - :param storage_device: Torch device to save inactive model in [torch.device('cpu')] - :param precision: Precision for loaded models [torch.float16] - :param lazy_offloading: Keep model in VRAM until another model needs to be loaded - :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially - :param sha_chunksize: Chunksize to use when calculating sha256 model hash - :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache - operation, and 
the result will be logged (at debug level). There is a time cost to capturing the memory - snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's - behaviour. - """ - self.model_infos: Dict[str, ModelBase] = {} - # allow lazy offloading only when vram cache enabled - self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0 - self.precision: torch.dtype = precision - self.max_cache_size: float = max_cache_size - self.max_vram_cache_size: float = max_vram_cache_size - self.execution_device: torch.device = execution_device - self.storage_device: torch.device = storage_device - self.sha_chunksize = sha_chunksize - self.logger = logger - self._log_memory_usage = log_memory_usage - - # used for stats collection - self.stats = None - - self._cached_models = {} - self._cache_stack = [] - - def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: - if self._log_memory_usage: - return MemorySnapshot.capture() - return None - - def get_key( - self, - model_path: str, - base_model: BaseModelType, - model_type: ModelType, - submodel_type: Optional[SubModelType] = None, - ): - key = f"{model_path}:{base_model}:{model_type}" - if submodel_type: - key += f":{submodel_type}" - return key - - def _get_model_info( - self, - model_path: str, - model_class: Type[ModelBase], - base_model: BaseModelType, - model_type: ModelType, - ): - model_info_key = self.get_key( - model_path=model_path, - base_model=base_model, - model_type=model_type, - submodel_type=None, - ) - - if model_info_key not in self.model_infos: - self.model_infos[model_info_key] = model_class( - model_path, - base_model, - model_type, - ) - - return self.model_infos[model_info_key] - - # TODO: args - def get_model( - self, - model_path: Union[str, Path], - model_class: Type[ModelBase], - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - gpu_load: bool = True, - ) -> Any: - if not isinstance(model_path, Path): - model_path = Path(model_path) - - if not os.path.exists(model_path): - raise Exception(f"Model not found: {model_path}") - - model_info = self._get_model_info( - model_path=model_path, - model_class=model_class, - base_model=base_model, - model_type=model_type, - ) - key = self.get_key( - model_path=model_path, - base_model=base_model, - model_type=model_type, - submodel_type=submodel, - ) - # TODO: lock for no copies on simultaneous calls? - cache_entry = self._cached_models.get(key, None) - if cache_entry is None: - self.logger.info( - f"Loading model {model_path}, type" - f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}" - ) - if self.stats: - self.stats.misses += 1 - - self_reported_model_size_before_load = model_info.get_size(submodel) - # Remove old models from the cache to make room for the new model. - self._make_cache_room(self_reported_model_size_before_load) - - # Load the model from disk and capture a memory snapshot before/after. 
- start_load_time = time.time() - snapshot_before = self._capture_memory_snapshot() - with skip_torch_weight_init(): - model = model_info.get_model(child_type=submodel, torch_dtype=self.precision) - snapshot_after = self._capture_memory_snapshot() - end_load_time = time.time() - - self_reported_model_size_after_load = model_info.get_size(submodel) - - self.logger.debug( - f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n" - f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /" - f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - if abs(self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB: - self.logger.debug( - f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:" - f" {(self_reported_model_size_before_load/GIG):.2f}GB /" - f" {(self_reported_model_size_after_load/GIG):.2f}GB." - ) - - cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load) - self._cached_models[key] = cache_entry - else: - if self.stats: - self.stats.hits += 1 - - if self.stats: - self.stats.cache_size = self.max_cache_size * GIG - self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size()) - self.stats.in_cache = len(self._cached_models) - self.stats.loaded_model_sizes[key] = max( - self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel) - ) - - with suppress(Exception): - self._cache_stack.remove(key) - self._cache_stack.append(key) - - return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size) - - def _move_model_to_device(self, key: str, target_device: torch.device): - cache_entry = self._cached_models[key] - - source_device = cache_entry.model.device - # Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support - # multi-GPU. - if torch.device(source_device).type == torch.device(target_device).type: - return - - start_model_to_time = time.time() - snapshot_before = self._capture_memory_snapshot() - cache_entry.model.to(target_device) - snapshot_after = self._capture_memory_snapshot() - end_model_to_time = time.time() - self.logger.debug( - f"Moved model '{key}' from {source_device} to" - f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n" - f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - if ( - snapshot_before is not None - and snapshot_after is not None - and snapshot_before.vram is not None - and snapshot_after.vram is not None - ): - vram_change = abs(snapshot_before.vram - snapshot_after.vram) - - # If the estimated model size does not match the change in VRAM, log a warning. - if not math.isclose( - vram_change, - cache_entry.size, - rel_tol=0.1, - abs_tol=10 * MB, - ): - self.logger.debug( - f"Moving model '{key}' from {source_device} to" - f" {target_device} caused an unexpected change in VRAM usage. The model's" - " estimated size may be incorrect. 
Estimated model size:" - f" {(cache_entry.size/GIG):.3f} GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - class ModelLocker(object): - def __init__(self, cache, key, model, gpu_load, size_needed): - """ - :param cache: The model_cache object - :param key: The key of the model to lock in GPU - :param model: The model to lock - :param gpu_load: True if load into gpu - :param size_needed: Size of the model to load - """ - self.gpu_load = gpu_load - self.cache = cache - self.key = key - self.model = model - self.size_needed = size_needed - self.cache_entry = self.cache._cached_models[self.key] - - def __enter__(self) -> Any: - if not hasattr(self.model, "to"): - return self.model - - # NOTE that the model has to have the to() method in order for this - # code to move it into GPU! - if self.gpu_load: - self.cache_entry.lock() - - try: - if self.cache.lazy_offloading: - self.cache._offload_unlocked_models(self.size_needed) - - self.cache._move_model_to_device(self.key, self.cache.execution_device) - - self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}") - self.cache._print_cuda_stats() - - except Exception: - self.cache_entry.unlock() - raise - - # TODO: not fully understand - # in the event that the caller wants the model in RAM, we - # move it into CPU if it is in GPU and not locked - elif self.cache_entry.loaded and not self.cache_entry.locked: - self.cache._move_model_to_device(self.key, self.cache.storage_device) - - return self.model - - def __exit__(self, type, value, traceback): - if not hasattr(self.model, "to"): - return - - self.cache_entry.unlock() - if not self.cache.lazy_offloading: - self.cache._offload_unlocked_models() - self.cache._print_cuda_stats() - - # TODO: should it be called untrack_model? - def uncache_model(self, cache_id: str): - with suppress(ValueError): - self._cache_stack.remove(cache_id) - self._cached_models.pop(cache_id, None) - - def model_hash( - self, - model_path: Union[str, Path], - ) -> str: - """ - Given the HF repo id or path to a model on disk, returns a unique - hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs - - :param model_path: Path to model file/directory on disk. 
- """ - return self._local_model_hash(model_path) - - def cache_size(self) -> float: - """Return the current size of the cache, in GB.""" - return self._cache_size() / GIG - - def _has_cuda(self) -> bool: - return self.execution_device.type == "cuda" - - def _print_cuda_stats(self): - vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) - ram = "%4.2fG" % self.cache_size() - - cached_models = 0 - loaded_models = 0 - locked_models = 0 - for model_info in self._cached_models.values(): - cached_models += 1 - if model_info.loaded: - loaded_models += 1 - if model_info.locked: - locked_models += 1 - - self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" - f" {cached_models}/{loaded_models}/{locked_models}" - ) - - def _cache_size(self) -> int: - return sum([m.size for m in self._cached_models.values()]) - - def _make_cache_room(self, model_size): - # calculate how much memory this model will require - # multiplier = 2 if self.precision==torch.float32 else 1 - bytes_needed = model_size - maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes - current_size = self._cache_size() - - if current_size + bytes_needed > maximum_size: - self.logger.debug( - f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional" - f" {(bytes_needed/GIG):.2f} GB" - ) - - self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}") - - pos = 0 - models_cleared = 0 - while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): - model_key = self._cache_stack[pos] - cache_entry = self._cached_models[model_key] - - refs = sys.getrefcount(cache_entry.model) - - # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly - # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: - # https://docs.python.org/3/library/gc.html#gc.get_referrers - - # manualy clear local variable references of just finished function calls - # for some reason python don't want to collect it even by gc.collect() immidiately - if refs > 2: - while True: - cleared = False - for referrer in gc.get_referrers(cache_entry.model): - if type(referrer).__name__ == "frame": - # RuntimeError: cannot clear an executing frame - with suppress(RuntimeError): - referrer.clear() - cleared = True - # break - - # repeat if referrers changes(due to frame clear), else exit loop - if cleared: - gc.collect() - else: - break - - device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None - self.logger.debug( - f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," - f" refs: {refs}" - ) - - # Expected refs: - # 1 from cache_entry - # 1 from getrefcount function - # 1 from onnx runtime object - if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): - self.logger.debug( - f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" - ) - current_size -= cache_entry.size - models_cleared += 1 - if self.stats: - self.stats.cleared += 1 - del self._cache_stack[pos] - del self._cached_models[model_key] - del cache_entry - - else: - pos += 1 - - if models_cleared > 0: - # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but - # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. 
(The time cost - # is high even if no garbage gets collected.) - # - # Calling gc.collect(...) when a model is cleared seems like a good middle-ground: - # - If models had to be cleared, it's a signal that we are close to our memory limit. - # - If models were cleared, there's a good chance that there's a significant amount of garbage to be - # collected. - # - # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up - # immediately when their reference count hits 0. - gc.collect() - - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() - - self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}") - - def _offload_unlocked_models(self, size_needed: int = 0): - reserved = self.max_vram_cache_size * GIG - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") - for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): - if vram_in_use <= reserved: - break - if not cache_entry.locked and cache_entry.loaded: - self._move_model_to_device(model_key, self.storage_device) - - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") - - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() - - def _local_model_hash(self, model_path: Union[str, Path]) -> str: - sha = hashlib.sha256() - path = Path(model_path) - - hashpath = path / "checksum.sha256" - if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime: - with open(hashpath) as f: - hash = f.read() - return hash - - self.logger.debug(f"computing hash of model {path.name}") - for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")): - with open(file, "rb") as f: - while chunk := f.read(self.sha_chunksize): - sha.update(chunk) - hash = sha.hexdigest() - with open(hashpath, "w") as f: - f.write(hash) - return hash diff --git a/invokeai/backend/model_management_OLD/model_load_optimizations.py b/invokeai/backend/model_management_OLD/model_load_optimizations.py deleted file mode 100644 index a46d262175f..00000000000 --- a/invokeai/backend/model_management_OLD/model_load_optimizations.py +++ /dev/null @@ -1,30 +0,0 @@ -from contextlib import contextmanager - -import torch - - -def _no_op(*args, **kwargs): - pass - - -@contextmanager -def skip_torch_weight_init(): - """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) - to skip weight initialization. - - By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular - distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is - completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager - monkey-patches common torch layers to skip the weight initialization step. 
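    For illustration (this usage sketch is not part of the original docstring), the intended
    pattern is to construct the layers inside the context and then load checkpoint weights,
    so the skipped initialization is never observable:

        with skip_torch_weight_init():
            layer = torch.nn.Linear(4096, 4096)        # weights left uninitialized
        layer.load_state_dict(checkpoint_state_dict)   # hypothetical state dict read from disk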
- """ - torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding] - saved_functions = [m.reset_parameters for m in torch_modules] - - try: - for torch_module in torch_modules: - torch_module.reset_parameters = _no_op - - yield None - finally: - for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): - torch_module.reset_parameters = saved_function diff --git a/invokeai/backend/model_management_OLD/model_manager.py b/invokeai/backend/model_management_OLD/model_manager.py deleted file mode 100644 index 84d93f15fa8..00000000000 --- a/invokeai/backend/model_management_OLD/model_manager.py +++ /dev/null @@ -1,1121 +0,0 @@ -"""This module manages the InvokeAI `models.yaml` file, mapping -symbolic diffusers model names to the paths and repo_ids used by the -underlying `from_pretrained()` call. - -SYNOPSIS: - - mgr = ModelManager('/home/phi/invokeai/configs/models.yaml') - sd1_5 = mgr.get_model('stable-diffusion-v1-5', - model_type=ModelType.Main, - base_model=BaseModelType.StableDiffusion1, - submodel_type=SubModelType.Unet) - with sd1_5 as unet: - run_some_inference(unet) - -FETCHING MODELS: - -Models are described using four attributes: - - 1) model_name -- the symbolic name for the model - - 2) ModelType -- an enum describing the type of the model. Currently - defined types are: - ModelType.Main -- a full model capable of generating images - ModelType.Vae -- a VAE model - ModelType.Lora -- a LoRA or LyCORIS fine-tune - ModelType.TextualInversion -- a textual inversion embedding - ModelType.ControlNet -- a ControlNet model - ModelType.IPAdapter -- an IPAdapter model - - 3) BaseModelType -- an enum indicating the stable diffusion base model, one of: - BaseModelType.StableDiffusion1 - BaseModelType.StableDiffusion2 - - 4) SubModelType (optional) -- an enum that refers to one of the submodels contained - within the main model. Values are: - - SubModelType.UNet - SubModelType.TextEncoder - SubModelType.Tokenizer - SubModelType.Scheduler - SubModelType.SafetyChecker - -To fetch a model, use `manager.get_model()`. This takes the symbolic -name of the model, the ModelType, the BaseModelType and the -SubModelType. The latter is required for ModelType.Main. - -get_model() will return a ModelInfo object that can then be used in -context to retrieve the model and move it into GPU VRAM (on GPU -systems). - -A typical example is: - - sd1_5 = mgr.get_model('stable-diffusion-v1-5', - model_type=ModelType.Main, - base_model=BaseModelType.StableDiffusion1, - submodel_type=SubModelType.UNet) - with sd1_5 as unet: - run_some_inference(unet) - -The ModelInfo object provides a number of useful fields describing the -model, including: - - name -- symbolic name of the model - base_model -- base model (BaseModelType) - type -- model type (ModelType) - location -- path to the model file - precision -- torch precision of the model - hash -- unique sha256 checksum for this model - -SUBMODELS: - -When fetching a main model, you must specify the submodel. Retrieval -of full pipelines is not supported. - - vae_info = mgr.get_model('stable-diffusion-1.5', - model_type = ModelType.Main, - base_model = BaseModelType.StableDiffusion1, - submodel_type = SubModelType.Vae - ) - with vae_info as vae: - do_something(vae) - -This rule does not apply to controlnets, embeddings, loras and standalone -VAEs, which do not have submodels. 
- -LISTING MODELS - -The model_names() method will return a list of Tuples describing each -model it knows about: - - >> mgr.model_names() - [ - ('stable-diffusion-1.5', , ), - ('stable-diffusion-2.1', , ), - ('inpaint', , ) - ('Ink scenery', , ) - ... - ] - -The tuple is in the correct order to pass to get_model(): - - for m in mgr.model_names(): - info = get_model(*m) - -In contrast, the list_models() method returns a list of dicts, each -providing information about a model defined in models.yaml. For example: - - >>> models = mgr.list_models() - >>> json.dumps(models[0]) - {"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny", - "model_format": "diffusers", - "name": "canny", - "base_model": "sd-1", - "type": "controlnet" - } - -You can filter by model type and base model as shown here: - - - controlnets = mgr.list_models(model_type=ModelType.ControlNet, - base_model=BaseModelType.StableDiffusion1) - for c in controlnets: - name = c['name'] - format = c['model_format'] - path = c['path'] - type = c['type'] - # etc - -ADDING AND REMOVING MODELS - -At startup time, the `models` directory will be scanned for -checkpoints, diffusers pipelines, controlnets, LoRAs and TI -embeddings. New entries will be added to the model manager and defunct -ones removed. Anything that is a main model (ModelType.Main) will be -added to models.yaml. For scanning to succeed, files need to be in -their proper places. For example, a controlnet folder built on the -stable diffusion 2 base, will need to be placed in -`models/sd-2/controlnet`. - -Layout of the `models` directory: - - models - ├── sd-1 - │ ├── controlnet - │ ├── lora - │ ├── main - │ └── embedding - ├── sd-2 - │ ├── controlnet - │ ├── lora - │ ├── main - │ └── embedding - └── core - ├── face_reconstruction - │ ├── codeformer - │ └── gfpgan - ├── sd-conversion - │ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs - │ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs - │ └── stable-diffusion-safety-checker - └── upscaling - └─── esrgan - - - -class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are not listed -explicitly in models.yaml, but are added to the in-memory data -structure at initialization time by scanning the models directory. The -in-memory data structure can be resynchronized by calling -`manager.scan_models_directory()`. - -Files and folders placed inside the `autoimport` paths (paths -defined in `invokeai.yaml`) will also be scanned for new models at -initialization time and added to `models.yaml`. Files will not be -moved from this location but preserved in-place. These directories -are: - - configuration default description - ------------- ------- ----------- - autoimport_dir autoimport/main main models - lora_dir autoimport/lora LoRA/LyCORIS models - embedding_dir autoimport/embedding TI embeddings - controlnet_dir autoimport/controlnet ControlNet models - -In actuality, models located in any of these directories are scanned -to determine their type, so it isn't strictly necessary to organize -the different types in this way. This entry in `invokeai.yaml` will -recursively scan all subdirectories within `autoimport`, scan models -files it finds, and import them if recognized. - - Paths: - autoimport_dir: autoimport - -A model can be manually added using `add_model()` using the model's -name, base model, type and a dict of model attributes. See -`invokeai/backend/model_management/models` for the attributes required -by each model type. 
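A minimal illustration (the attribute keys actually required depend on the model class,
as noted above; the names and values here are hypothetical):

    mgr.add_model(
        model_name='my-fine-tune',
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        model_attributes={
            'path': 'sd-1/main/my-fine-tune.safetensors',
            'description': 'an example checkpoint fine-tune',
            'model_format': 'checkpoint',
            'variant': 'normal',
        },
    )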
- -A model can be deleted using `del_model()`, providing the same -identifying information as `get_model()` - -The `heuristic_import()` method will take a set of strings -corresponding to local paths, remote URLs, and repo_ids, probe the -object to determine what type of model it is (if any), and import new -models into the manager. If passed a directory, it will recursively -scan it for models to import. The return value is a set of the models -successfully added. - -MODELS.YAML - -The general format of a models.yaml section is: - - type-of-model/name-of-model: - path: /path/to/local/file/or/directory - description: a description - format: diffusers|checkpoint - variant: normal|inpaint|depth - -The type of model is given in the stanza key, and is one of -{main, vae, lora, controlnet, textual} - -The format indicates whether the model is organized as a diffusers -folder with model subdirectories, or is contained in a single -checkpoint or safetensors file. - -The path points to a file or directory on disk. If a relative path, -the root is the InvokeAI ROOTDIR. - -""" -from __future__ import annotations - -import hashlib -import os -import textwrap -import types -from dataclasses import dataclass -from pathlib import Path -from shutil import move, rmtree -from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union, cast - -import torch -import yaml -from omegaconf import OmegaConf -from omegaconf.dictconfig import DictConfig -from pydantic import BaseModel, ConfigDict, Field - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util import CUDA_DEVICE, Chdir - -from .model_cache import ModelCache, ModelLocker -from .model_search import ModelSearch -from .models import ( - MODEL_CLASSES, - BaseModelType, - DuplicateModelException, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelError, - ModelNotFoundException, - ModelType, - SchedulerPredictionType, - SubModelType, -) - -# We are only starting to number the config file with release 3. -# The config file version doesn't have to start at release version, but it will help -# reduce confusion. -CONFIG_FILE_VERSION = "3.0.0" - - -@dataclass -class LoadedModelInfo: - context: ModelLocker - name: str - base_model: BaseModelType - type: ModelType - hash: str - location: Union[Path, str] - precision: torch.dtype - _cache: Optional[ModelCache] = None - - def __enter__(self): - return self.context.__enter__() - - def __exit__(self, *args, **kwargs): - self.context.__exit__(*args, **kwargs) - - -class AddModelResult(BaseModel): - name: str = Field(description="The name of the model after installation") - model_type: ModelType = Field(description="The type of model") - base_model: BaseModelType = Field(description="The base model") - config: ModelConfigBase = Field(description="The configuration of the model") - - model_config = ConfigDict(protected_namespaces=()) - - -MAX_CACHE_SIZE = 6.0 # GB - - -class ConfigMeta(BaseModel): - version: str - - -class ModelManager(object): - """ - High-level interface to model management. - """ - - logger: types.ModuleType = logger - - def __init__( - self, - config: Union[Path, DictConfig, str], - device_type: torch.device = CUDA_DEVICE, - precision: torch.dtype = torch.float16, - max_cache_size=MAX_CACHE_SIZE, - sequential_offload=False, - logger: types.ModuleType = logger, - ): - """ - Initialize with the path to the models.yaml config file. 
- Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - self.config_path = None - if isinstance(config, (str, Path)): - self.config_path = Path(config) - if not self.config_path.exists(): - logger.warning(f"The file {self.config_path} was not found. Initializing a new file") - self.initialize_model_config(self.config_path) - config = OmegaConf.load(self.config_path) - - elif not isinstance(config, DictConfig): - raise ValueError("config argument must be an OmegaConf object, a Path or a string") - - self.config_meta = ConfigMeta(**config.pop("__metadata__")) - # TODO: metadata not found - # TODO: version check - - self.app_config = InvokeAIAppConfig.get_config() - self.logger = logger - self.cache = ModelCache( - max_cache_size=max_cache_size, - max_vram_cache_size=self.app_config.vram_cache_size, - lazy_offloading=self.app_config.lazy_offload, - execution_device=device_type, - precision=precision, - sequential_offload=sequential_offload, - logger=logger, - log_memory_usage=self.app_config.log_memory_usage, - ) - - self._read_models(config) - - def _read_models(self, config: Optional[DictConfig] = None): - if not config: - if self.config_path: - config = OmegaConf.load(self.config_path) - else: - return - - self.models = {} - for model_key, model_config in config.items(): - if model_key.startswith("_"): - continue - model_name, base_model, model_type = self.parse_key(model_key) - model_class = self._get_implementation(base_model, model_type) - # alias for config file - model_config["model_format"] = model_config.pop("format") - self.models[model_key] = model_class.create_config(**model_config) - - # check config version number and update on disk/RAM if necessary - self.cache_keys = {} - - # add controlnet, lora and textual_inversion models from disk - self.scan_models_directory() - - def sync_to_config(self): - """ - Call this when `models.yaml` has been changed externally. - This will reinitialize internal data structures - """ - # Reread models directory; note that this will reinitialize the cache, - # causing otherwise unreferenced models to be removed from memory - self._read_models() - - def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool: - """ - Given a model name, returns True if it is a valid identifier. - - :param model_name: symbolic name of the model in models.yaml - :param model_type: ModelType enum indicating the type of model to return - :param base_model: BaseModelType enum indicating the base model used by this model - :param rescan: if True, scan_models_directory - """ - model_key = self.create_key(model_name, base_model, model_type) - exists = model_key in self.models - - # if model not found try to find it (maybe file just pasted) - if rescan and not exists: - self.scan_models_directory(base_model=base_model, model_type=model_type) - exists = self.model_exists(model_name, base_model, model_type, rescan=False) - - return exists - - @classmethod - def create_key( - cls, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> str: - # In 3.11, the behavior of (str,enum) when interpolated into a - # string has changed. The next two lines are defensive. 
- base_model = BaseModelType(base_model) - model_type = ModelType(model_type) - return f"{base_model.value}/{model_type.value}/{model_name}" - - @classmethod - def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]: - base_model_str, model_type_str, model_name = model_key.split("/", 2) - try: - model_type = ModelType(model_type_str) - except Exception: - raise Exception(f"Unknown model type: {model_type_str}") - - try: - base_model = BaseModelType(base_model_str) - except Exception: - raise Exception(f"Unknown base model: {base_model_str}") - - return (model_name, base_model, model_type) - - def _get_model_cache_path(self, model_path): - return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest()) - - @classmethod - def initialize_model_config(cls, config_path: Path): - """Create empty config file""" - with open(config_path, "w") as yaml_file: - yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel_type: Optional[SubModelType] = None, - ) -> LoadedModelInfo: - """Given a model named identified in models.yaml, return - an ModelInfo object describing it. - :param model_name: symbolic name of the model in models.yaml - :param model_type: ModelType enum indicating the type of model to return - :param base_model: BaseModelType enum indicating the base model used by this model - :param submodel_type: an ModelType enum indicating the portion of - the model to retrieve (e.g. ModelType.Vae) - """ - model_key = self.create_key(model_name, base_model, model_type) - - if not self.model_exists(model_name, base_model, model_type, rescan=True): - raise ModelNotFoundException(f"Model not found - {model_key}") - - model_config = self._get_model_config(base_model, model_name, model_type) - - model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) - - if is_submodel_override: - model_type = submodel_type - submodel_type = None - - model_class = self._get_implementation(base_model, model_type) - - if not model_path.exists(): - if model_class.save_to_config: - self.models[model_key].error = ModelError.NotFound - raise Exception(f'Files for model "{model_key}" not found at {model_path}') - - else: - self.models.pop(model_key, None) - raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}') - - # TODO: path - # TODO: is it accurate to use path as id - dst_convert_path = self._get_model_cache_path(model_path) - - model_path = model_class.convert_if_required( - base_model=base_model, - model_path=str(model_path), # TODO: refactor str/Path types logic - output_path=dst_convert_path, - config=model_config, - ) - - model_context = self.cache.get_model( - model_path=model_path, - model_class=model_class, - base_model=base_model, - model_type=model_type, - submodel_type=submodel_type, - ) - - if model_key not in self.cache_keys: - self.cache_keys[model_key] = set() - self.cache_keys[model_key].add(model_context.key) - - model_hash = "" # TODO: - - return LoadedModelInfo( - context=model_context, - name=model_name, - base_model=base_model, - type=submodel_type or model_type, - hash=model_hash, - location=model_path, # TODO: - precision=self.cache.precision, - _cache=self.cache, - ) - - def _get_model_path( - self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None - ) -> (Path, bool): - """Extract a model's filesystem path from its config. 
- - :return: The fully qualified Path of the module (or submodule). - """ - model_path = model_config.path - is_submodel_override = False - - # Does the config explicitly override the submodel? - if submodel_type is not None and hasattr(model_config, submodel_type): - submodel_path = getattr(model_config, submodel_type) - if submodel_path is not None and len(submodel_path) > 0: - model_path = getattr(model_config, submodel_type) - is_submodel_override = True - - model_path = self.resolve_model_path(model_path) - return model_path, is_submodel_override - - def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase: - """Get a model's config object.""" - model_key = self.create_key(model_name, base_model, model_type) - try: - model_config = self.models[model_key] - except KeyError: - raise ModelNotFoundException(f"Model not found - {model_key}") - return model_config - - def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: - """Get the concrete implementation class for a specific model type.""" - model_class = MODEL_CLASSES[base_model][model_type] - return model_class - - def _instantiate( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel_type: Optional[SubModelType] = None, - ) -> ModelBase: - """Make a new instance of this model, without loading it.""" - model_config = self._get_model_config(base_model, model_name, model_type) - model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) - # FIXME: do non-overriden submodels get the right class? - constructor = self._get_implementation(base_model, model_type) - instance = constructor(model_path, base_model, model_type) - return instance - - def model_info( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> Union[dict, None]: - """ - Given a model name returns the OmegaConf (dict-like) object describing it. - """ - model_key = self.create_key(model_name, base_model, model_type) - if model_key in self.models: - return self.models[model_key].model_dump(exclude_defaults=True) - else: - return None # TODO: None or empty dict on not found - - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Return a list of (str, BaseModelType, ModelType) corresponding to all models - known to the configuration. - """ - return [(self.parse_key(x)) for x in self.models.keys()] - - def list_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> Union[dict, None]: - """ - Returns a dict describing one installed model, using - the combined format of the list_models() method. - """ - models = self.list_models(base_model, model_type, model_name) - if len(models) >= 1: - return models[0] - else: - return None - - def list_models( - self, - base_model: Optional[BaseModelType] = None, - model_type: Optional[ModelType] = None, - model_name: Optional[str] = None, - ) -> list[dict]: - """ - Return a list of models. 
- """ - - model_keys = ( - [self.create_key(model_name, base_model, model_type)] - if model_name and base_model and model_type - else sorted(self.models, key=str.casefold) - ) - models = [] - for model_key in model_keys: - model_config = self.models.get(model_key) - if not model_config: - self.logger.error(f"Unknown model {model_name}") - raise ModelNotFoundException(f"Unknown model {model_name}") - - cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) - if base_model is not None and cur_base_model != base_model: - continue - if model_type is not None and cur_model_type != model_type: - continue - - model_dict = dict( - **model_config.model_dump(exclude_defaults=True), - # OpenAPIModelInfoBase - model_name=cur_model_name, - base_model=cur_base_model, - model_type=cur_model_type, - ) - - # expose paths as absolute to help web UI - if path := model_dict.get("path"): - model_dict["path"] = str(self.resolve_model_path(path)) - models.append(model_dict) - - return models - - def print_models(self) -> None: - """ - Print a table of models and their descriptions. This needs to be redone - """ - # TODO: redo - for model_dict in self.list_models(): - for _model_name, model_info in model_dict.items(): - line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}' - print(line) - - # Tested - LS - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model. - """ - model_key = self.create_key(model_name, base_model, model_type) - model_cfg = self.models.pop(model_key, None) - - if model_cfg is None: - raise ModelNotFoundException(f"Unknown model {model_key}") - - # note: it not garantie to release memory(model can has other references) - cache_ids = self.cache_keys.pop(model_key, []) - for cache_id in cache_ids: - self.cache.uncache_model(cache_id) - - # if model inside invoke models folder - delete files - model_path = self.resolve_model_path(model_cfg.path) - cache_path = self._get_model_cache_path(model_path) - if cache_path.exists(): - rmtree(str(cache_path)) - - if model_path.is_relative_to(self.app_config.models_path): - if model_path.is_dir(): - rmtree(str(model_path)) - else: - model_path.unlink() - self.commit() - - # LS: tested - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory and the - method will return True. Will fail with an assertion error if provided - attributes are incorrect or the model name is missing. - - The returned dict has the same format as the dict returned by - model_info(). 
- """ - # relativize paths as they go in - this makes it easier to move the models directory around - if path := model_attributes.get("path"): - model_attributes["path"] = str(self.relative_model_path(Path(path))) - - model_class = self._get_implementation(base_model, model_type) - model_config = model_class.create_config(**model_attributes) - model_key = self.create_key(model_name, base_model, model_type) - - if model_key in self.models and not clobber: - raise Exception(f'Attempt to overwrite existing model definition "{model_key}"') - - old_model = self.models.pop(model_key, None) - if old_model is not None: - # TODO: if path changed and old_model.path inside models folder should we delete this too? - - # remove conversion cache as config changed - old_model_path = self.resolve_model_path(old_model.path) - old_model_cache = self._get_model_cache_path(old_model_path) - if old_model_cache.exists(): - if old_model_cache.is_dir(): - rmtree(str(old_model_cache)) - else: - old_model_cache.unlink() - - # remove in-memory cache - # note: it not guaranteed to release memory(model can has other references) - cache_ids = self.cache_keys.pop(model_key, []) - for cache_id in cache_ids: - self.cache.uncache_model(cache_id) - - self.models[model_key] = model_config - self.commit() - - return AddModelResult( - name=model_name, - model_type=model_type, - base_model=base_model, - config=model_config, - ) - - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: Optional[str] = None, - new_base: Optional[BaseModelType] = None, - ) -> None: - """ - Rename or rebase a model. - """ - if new_name is None and new_base is None: - self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.") - return - - model_key = self.create_key(model_name, base_model, model_type) - model_cfg = self.models.get(model_key, None) - if not model_cfg: - raise ModelNotFoundException(f"Unknown model: {model_key}") - - old_path = self.resolve_model_path(model_cfg.path) - new_name = new_name or model_name - new_base = new_base or base_model - new_key = self.create_key(new_name, new_base, model_type) - if new_key in self.models: - raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"') - - # if this is a model file/directory that we manage ourselves, we need to move it - if old_path.is_relative_to(self.app_config.models_path): - # keep the suffix! 
- if old_path.is_file(): - new_name = Path(new_name).with_suffix(old_path.suffix).as_posix() - new_path = self.resolve_model_path( - Path( - BaseModelType(new_base).value, - ModelType(model_type).value, - new_name, - ) - ) - move(old_path, new_path) - model_cfg.path = str(new_path.relative_to(self.app_config.models_path)) - - # clean up caches - old_model_cache = self._get_model_cache_path(old_path) - if old_model_cache.exists(): - if old_model_cache.is_dir(): - rmtree(str(old_model_cache)) - else: - old_model_cache.unlink() - - cache_ids = self.cache_keys.pop(model_key, []) - for cache_id in cache_ids: - self.cache.uncache_model(cache_id) - - self.models.pop(model_key, None) # delete - self.models[new_key] = model_cfg - self.commit() - - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - dest_directory: Optional[Path] = None, - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - - This will raise a ValueError unless the model is a checkpoint. - """ - info = self.model_info(model_name, base_model, model_type) - - if info is None: - raise FileNotFoundError(f"model not found: {model_name}") - - if info["model_format"] != "checkpoint": - raise ValueError(f"not a checkpoint format model: {model_name}") - - # We are taking advantage of a side effect of get_model() that converts check points - # into cached diffusers directories stored at `location`. It doesn't matter - # what submodeltype we request here, so we get the smallest. - submodel = {"submodel_type": SubModelType.Scheduler} if model_type == ModelType.Main else {} - model = self.get_model( - model_name, - base_model, - model_type, - **submodel, - ) - checkpoint_path = self.resolve_model_path(info["path"]) - old_diffusers_path = self.resolve_model_path(model.location) - new_diffusers_path = ( - dest_directory or self.app_config.models_path / base_model.value / model_type.value - ) / model_name - if new_diffusers_path.exists(): - raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") - - try: - move(old_diffusers_path, new_diffusers_path) - info["model_format"] = "diffusers" - info["path"] = ( - str(new_diffusers_path) - if dest_directory - else str(new_diffusers_path.relative_to(self.app_config.models_path)) - ) - info.pop("config") - - result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True) - except Exception: - # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error! 
- rmtree(new_diffusers_path) - raise - - if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path): - checkpoint_path.unlink() - - return result - - def resolve_model_path(self, path: Union[Path, str]) -> Path: - """return relative paths based on configured models_path""" - return self.app_config.models_path / path - - def relative_model_path(self, model_path: Path) -> Path: - if model_path.is_relative_to(self.app_config.models_path): - model_path = model_path.relative_to(self.app_config.models_path) - return model_path - - def search_models(self, search_folder): - self.logger.info(f"Finding Models In: {search_folder}") - models_folder_ckpt = Path(search_folder).glob("**/*.ckpt") - models_folder_safetensors = Path(search_folder).glob("**/*.safetensors") - - ckpt_files = [x for x in models_folder_ckpt if x.is_file()] - safetensor_files = [x for x in models_folder_safetensors if x.is_file()] - - files = ckpt_files + safetensor_files - - found_models = [] - for file in files: - location = str(file.resolve()).replace("\\", "/") - if "model.safetensors" not in location and "diffusion_pytorch_model.safetensors" not in location: - found_models.append({"name": file.stem, "location": location}) - - return search_folder, found_models - - def commit(self, conf_file: Optional[Path] = None) -> None: - """ - Write current configuration out to the indicated file. - """ - data_to_save = {} - data_to_save["__metadata__"] = self.config_meta.model_dump() - - for model_key, model_config in self.models.items(): - model_name, base_model, model_type = self.parse_key(model_key) - model_class = self._get_implementation(base_model, model_type) - if model_class.save_to_config: - # TODO: or exclude_unset better fits here? - data_to_save[model_key] = cast(BaseModel, model_config).model_dump( - exclude_defaults=True, exclude={"error"}, mode="json" - ) - # alias for config file - data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format") - - yaml_str = OmegaConf.to_yaml(data_to_save) - config_file_path = conf_file or self.config_path - assert config_file_path is not None, "no config file path to write to" - config_file_path = self.app_config.root_path / config_file_path - tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp") - try: - with open(tmpfile, "w", encoding="utf-8") as outfile: - outfile.write(self.preamble()) - outfile.write(yaml_str) - os.replace(tmpfile, config_file_path) - except OSError as err: - self.logger.warning(f"Could not modify the config file at {config_file_path}") - self.logger.warning(err) - - def preamble(self) -> str: - """ - Returns the preamble for the config file. - """ - return textwrap.dedent( - """ - # This file describes the alternative machine learning models - # available to InvokeAI script. - # - # To add a new model, follow the examples below. Each - # model requires a model config file, a weights file, - # and the width and height of the images it - # was trained on. 
- """ - ) - - def scan_models_directory( - self, - base_model: Optional[BaseModelType] = None, - model_type: Optional[ModelType] = None, - ): - loaded_files = set() - new_models_found = False - - self.logger.info(f"Scanning {self.app_config.models_path} for new models") - with Chdir(self.app_config.models_path): - for model_key, model_config in list(self.models.items()): - model_name, cur_base_model, cur_model_type = self.parse_key(model_key) - - # Patch for relative path bug in older models.yaml - paths should not - # be starting with a hard-coded 'models'. This will also fix up - # models.yaml when committed. - if model_config.path.startswith("models"): - model_config.path = str(Path(*Path(model_config.path).parts[1:])) - - model_path = self.resolve_model_path(model_config.path).absolute() - if not model_path.exists(): - model_class = self._get_implementation(cur_base_model, cur_model_type) - if model_class.save_to_config: - model_config.error = ModelError.NotFound - self.models.pop(model_key, None) - else: - self.models.pop(model_key, None) - else: - loaded_files.add(model_path) - - for cur_base_model in BaseModelType: - if base_model is not None and cur_base_model != base_model: - continue - - for cur_model_type in ModelType: - if model_type is not None and cur_model_type != model_type: - continue - model_class = self._get_implementation(cur_base_model, cur_model_type) - models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value)) - - if not models_dir.exists(): - continue # TODO: or create all folders? - - for model_path in models_dir.iterdir(): - if model_path not in loaded_files: # TODO: check - if model_path.name.startswith("."): - continue - model_name = model_path.name if model_path.is_dir() else model_path.stem - model_key = self.create_key(model_name, cur_base_model, cur_model_type) - - try: - if model_key in self.models: - raise DuplicateModelException(f"Model with key {model_key} added twice") - - model_path = self.relative_model_path(model_path) - model_config: ModelConfigBase = model_class.probe_config( - str(model_path), model_base=cur_base_model - ) - self.models[model_key] = model_config - new_models_found = True - except DuplicateModelException as e: - self.logger.warning(e) - except InvalidModelException as e: - self.logger.warning(f"Not a valid model: {model_path}. {e}") - except NotImplementedError as e: - self.logger.warning(e) - except Exception as e: - self.logger.warning(f"Error loading model {model_path}. {e}") - - imported_models = self.scan_autoimport_directory() - if (new_models_found or imported_models) and self.config_path: - self.commit() - - def scan_autoimport_directory(self) -> Dict[str, AddModelResult]: - """ - Scan the autoimport directory (if defined) and import new models, delete defunct models. 
- """ - # avoid circular import - from invokeai.backend.install.model_install_backend import ModelInstall - from invokeai.frontend.install.model_install import ask_user_for_prediction_type - - class ScanAndImport(ModelSearch): - def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall): - super().__init__(directories, logger) - self.installer = installer - self.ignore = ignore - - def on_search_started(self): - self.new_models_found = {} - - def on_model_found(self, model: Path): - if model not in self.ignore: - self.new_models_found.update(self.installer.heuristic_import(model)) - - def on_search_completed(self): - self.logger.info( - f"Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models" - ) - - def models_found(self): - return self.new_models_found - - config = self.app_config - - # LS: hacky - # Patch in the SD VAE from core so that it is available for use by the UI - try: - self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))}) - except Exception: - pass - - installer = ModelInstall( - config=self.app_config, - model_manager=self, - prediction_type_helper=ask_user_for_prediction_type, - ) - known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()} - directories = { - config.root_path / x - for x in [ - config.autoimport_dir, - config.lora_dir, - config.embedding_dir, - config.controlnet_dir, - ] - if x - } - scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer) - scanner.search() - - return scanner.models_found() - - def heuristic_import( - self, - items_to_import: Set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> Dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. 
- - May return the following exceptions: - - ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL - - ValueError - a corresponding model already exists - """ - # avoid circular import here - from invokeai.backend.install.model_install_backend import ModelInstall - - successfully_installed = {} - - installer = ModelInstall( - config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self - ) - for thing in items_to_import: - installed = installer.heuristic_import(thing) - successfully_installed.update(installed) - self.commit() - return successfully_installed diff --git a/invokeai/backend/model_management_OLD/model_merge.py b/invokeai/backend/model_management_OLD/model_merge.py deleted file mode 100644 index a9f0a23618e..00000000000 --- a/invokeai/backend/model_management_OLD/model_merge.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -invokeai.backend.model_management.model_merge exports: -merge_diffusion_models() -- combine multiple models by location and return a pipeline object -merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml - -Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team -""" - -import warnings -from enum import Enum -from pathlib import Path -from typing import List, Optional, Union - -from diffusers import DiffusionPipeline -from diffusers import logging as dlogging - -import invokeai.backend.util.logging as logger - -from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType - - -class MergeInterpolationMethod(str, Enum): - WeightedSum = "weighted_sum" - Sigmoid = "sigmoid" - InvSigmoid = "inv_sigmoid" - AddDifference = "add_difference" - - -class ModelMerger(object): - def __init__(self, manager: ModelManager): - self.manager = manager - - def merge_diffusion_models( - self, - model_paths: List[Path], - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - **kwargs, - ) -> DiffusionPipeline: - """ - :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. 
- - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - verbosity = dlogging.get_verbosity() - dlogging.set_verbosity_error() - - pipe = DiffusionPipeline.from_pretrained( - model_paths[0], - custom_pipeline="checkpoint_merger", - ) - merged_pipe = pipe.merge( - pretrained_model_name_or_path_list=model_paths, - alpha=alpha, - interp=interp.value if interp else None, # diffusers API treats None as "weighted sum" - force=force, - **kwargs, - ) - dlogging.set_verbosity(verbosity) - return merged_pipe - - def merge_diffusion_models_and_save( - self, - model_names: List[str], - base_model: Union[BaseModelType, str], - merged_model_name: str, - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - merge_dest_directory: Optional[Path] = None, - **kwargs, - ) -> AddModelResult: - """ - :param models: up to three models, designated by their InvokeAI models.yaml model name - :param base_model: base model (must be the same for all merged models!) - :param merged_model_name: name for new model - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - model_paths = [] - config = self.manager.app_config - base_model = BaseModelType(base_model) - vae = None - - for mod in model_names: - info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) - assert info, f"model {mod}, base_model {base_model}, is unknown" - assert ( - info["model_format"] == "diffusers" - ), f"{mod} is not a diffusers model. 
It must be optimized before merging" - assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged" - assert ( - len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference - ), "When merging three models, only the 'add_difference' merge method is supported" - # pick up the first model's vae - if mod == model_names[0]: - vae = info.get("vae") - model_paths.extend([(config.root_path / info["path"]).as_posix()]) - - merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp) - logger.debug(f"interp = {interp}, merge_method={merge_method}") - merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs) - dump_path = ( - Path(merge_dest_directory) - if merge_dest_directory - else config.models_path / base_model.value / ModelType.Main.value - ) - dump_path.mkdir(parents=True, exist_ok=True) - dump_path = (dump_path / merged_model_name).as_posix() - - merged_pipe.save_pretrained(dump_path, safe_serialization=True) - attributes = { - "path": dump_path, - "description": f"Merge of models {', '.join(model_names)}", - "model_format": "diffusers", - "variant": ModelVariantType.Normal.value, - "vae": vae, - } - return self.manager.add_model( - merged_model_name, - base_model=base_model, - model_type=ModelType.Main, - model_attributes=attributes, - clobber=True, - ) diff --git a/invokeai/backend/model_management_OLD/model_probe.py b/invokeai/backend/model_management_OLD/model_probe.py deleted file mode 100644 index 74b1b72d317..00000000000 --- a/invokeai/backend/model_management_OLD/model_probe.py +++ /dev/null @@ -1,664 +0,0 @@ -import json -import re -from dataclasses import dataclass -from pathlib import Path -from typing import Callable, Dict, Literal, Optional, Union - -import safetensors.torch -import torch -from diffusers import ConfigMixin, ModelMixin -from picklescan.scanner import scan_file_path - -from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat - -from .models import ( - BaseModelType, - InvalidModelException, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SilenceWarnings, -) -from .models.base import read_checkpoint_meta -from .util import lora_token_vector_length - - -@dataclass -class ModelProbeInfo(object): - model_type: ModelType - base_type: BaseModelType - variant_type: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool - format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"] - image_size: int - name: Optional[str] = None - description: Optional[str] = None - - -class ProbeBase(object): - """forward declaration""" - - pass - - -class ModelProbe(object): - PROBES = { - "diffusers": {}, - "checkpoint": {}, - "onnx": {}, - } - - CLASS2TYPE = { - "StableDiffusionPipeline": ModelType.Main, - "StableDiffusionInpaintPipeline": ModelType.Main, - "StableDiffusionXLPipeline": ModelType.Main, - "StableDiffusionXLImg2ImgPipeline": ModelType.Main, - "StableDiffusionXLInpaintPipeline": ModelType.Main, - "LatentConsistencyModelPipeline": ModelType.Main, - "AutoencoderKL": ModelType.Vae, - "AutoencoderTiny": ModelType.Vae, - "ControlNetModel": ModelType.ControlNet, - "CLIPVisionModelWithProjection": ModelType.CLIPVision, - "T2IAdapter": ModelType.T2IAdapter, - } - - @classmethod - def register_probe( - cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase - ): - cls.PROBES[format][model_type] = probe_class - - @classmethod - def 
heuristic_probe( - cls, - model: Union[Dict, ModelMixin, Path], - prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, - ) -> ModelProbeInfo: - if isinstance(model, Path): - return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper) - elif isinstance(model, (dict, ModelMixin, ConfigMixin)): - return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper) - else: - raise InvalidModelException("model parameter {model} is neither a Path, nor a model") - - @classmethod - def probe( - cls, - model_path: Path, - model: Optional[Union[Dict, ModelMixin]] = None, - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> ModelProbeInfo: - """ - Probe the model at model_path and return sufficient information about it - to place it somewhere in the models directory hierarchy. If the model is - already loaded into memory, you may provide it as model in order to avoid - opening it a second time. The prediction_type_helper callable is a function that receives - the path to the model and returns the SchedulerPredictionType. - """ - if model_path: - format_type = "diffusers" if model_path.is_dir() else "checkpoint" - else: - format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint" - model_info = None - try: - model_type = ( - cls.get_model_type_from_folder(model_path, model) - if format_type == "diffusers" - else cls.get_model_type_from_checkpoint(model_path, model) - ) - format_type = "onnx" if model_type == ModelType.ONNX else format_type - probe_class = cls.PROBES[format_type].get(model_type) - if not probe_class: - return None - probe = probe_class(model_path, model, prediction_type_helper) - base_type = probe.get_base_type() - variant_type = probe.get_variant_type() - prediction_type = probe.get_scheduler_prediction_type() - name = cls.get_model_name(model_path) - description = f"{base_type.value} {model_type.value} model {name}" - format = probe.get_format() - model_info = ModelProbeInfo( - model_type=model_type, - base_type=base_type, - variant_type=variant_type, - prediction_type=prediction_type, - name=name, - description=description, - upcast_attention=( - base_type == BaseModelType.StableDiffusion2 - and prediction_type == SchedulerPredictionType.VPrediction - ), - format=format, - image_size=( - 1024 - if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) - else ( - 768 - if ( - base_type == BaseModelType.StableDiffusion2 - and prediction_type == SchedulerPredictionType.VPrediction - ) - else 512 - ) - ), - ) - except Exception: - raise - - return model_info - - @classmethod - def get_model_name(cls, model_path: Path) -> str: - if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: - return model_path.stem - else: - return model_path.name - - @classmethod - def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType: - if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"): - return None - - if model_path.name == "learned_embeds.bin": - return ModelType.TextualInversion - - ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True) - ckpt = ckpt.get("state_dict", ckpt) - - for key in ckpt.keys(): - if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): - return ModelType.Main - elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): - return ModelType.Vae - elif 
any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): - return ModelType.Lora - elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): - return ModelType.Lora - elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): - return ModelType.ControlNet - elif key in {"emb_params", "string_to_param"}: - return ModelType.TextualInversion - - else: - # diffusers-ti - if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): - return ModelType.TextualInversion - - raise InvalidModelException(f"Unable to determine model type for {model_path}") - - @classmethod - def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType: - """ - Get the model type of a hugging-face style folder. - """ - class_name = None - error_hint = None - if model: - class_name = model.__class__.__name__ - else: - for suffix in ["bin", "safetensors"]: - if (folder_path / f"learned_embeds.{suffix}").exists(): - return ModelType.TextualInversion - if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): - return ModelType.Lora - if (folder_path / "unet/model.onnx").exists(): - return ModelType.ONNX - if (folder_path / "image_encoder.txt").exists(): - return ModelType.IPAdapter - - i = folder_path / "model_index.json" - c = folder_path / "config.json" - config_path = i if i.exists() else c if c.exists() else None - - if config_path: - with open(config_path, "r") as file: - conf = json.load(file) - if "_class_name" in conf: - class_name = conf["_class_name"] - elif "architectures" in conf: - class_name = conf["architectures"][0] - else: - class_name = None - else: - error_hint = f"No model_index.json or config.json found in {folder_path}." - - if class_name and (type := cls.CLASS2TYPE.get(class_name)): - return type - else: - error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]" - - # give up - raise InvalidModelException( - f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "") - ) - - @classmethod - def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: - with SilenceWarnings(): - if model_path.suffix.endswith((".ckpt", ".pt", ".bin")): - cls._scan_model(model_path, model_path) - return torch.load(model_path, map_location="cpu") - else: - return safetensors.torch.load_file(model_path) - - @classmethod - def _scan_model(cls, model_name, checkpoint): - """ - Apply picklescanner to the indicated checkpoint and issue a warning - and option to exit if an infected file is identified. - """ - # scan model - scan_result = scan_file_path(checkpoint) - if scan_result.infected_files != 0: - raise Exception("The model {model_name} is potentially infected by malware. 
Aborting import.") - - -# ##################################################3 -# Checkpoint probing -# ##################################################3 -class ProbeBase(object): - def get_base_type(self) -> BaseModelType: - pass - - def get_variant_type(self) -> ModelVariantType: - pass - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - pass - - def get_format(self) -> str: - pass - - -class CheckpointProbeBase(ProbeBase): - def __init__( - self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None - ) -> BaseModelType: - self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) - self.checkpoint_path = checkpoint_path - self.helper = helper - - def get_base_type(self) -> BaseModelType: - pass - - def get_format(self) -> str: - return "checkpoint" - - def get_variant_type(self) -> ModelVariantType: - model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint) - if model_type != ModelType.Main: - return ModelVariantType.Normal - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - else: - raise InvalidModelException( - f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}" - ) - - -class PipelineCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: - return BaseModelType.StableDiffusionXL - elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: - return BaseModelType.StableDiffusionXLRefiner - else: - raise InvalidModelException("Cannot determine base type") - - def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: - """Return model prediction type.""" - # if there is a .yaml associated with this checkpoint, then we do not need - # to probe for the prediction type as it will be ignored. 
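        # Rough sketch of the decision order implemented below (checkpoint
        # contents are illustrative, nothing is read at this point):
        #
        #   .yaml config next to the checkpoint   -> None (the config wins)
        #   sd-2 and global_step == 220000        -> Epsilon      (sd2-base)
        #   sd-2 and global_step == 110000        -> VPrediction  (sd2-768)
        #   otherwise                             -> ask self.helper, then guess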
- if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists(): - return None - - type = self.get_base_type() - if type == BaseModelType.StableDiffusion2: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if "global_step" in checkpoint: - if checkpoint["global_step"] == 220000: - return SchedulerPredictionType.Epsilon - elif checkpoint["global_step"] == 110000: - return SchedulerPredictionType.VPrediction - if self.helper and self.checkpoint_path: - if helper_guess := self.helper(self.checkpoint_path): - return helper_guess - return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts - - elif type == BaseModelType.StableDiffusion1: - if self.helper and self.checkpoint_path: - if helper_guess := self.helper(self.checkpoint_path): - return helper_guess - return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts - else: - return None - - -class VaeCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - # I can't find any standalone 2.X VAEs to test with! - return BaseModelType.StableDiffusion1 - - -class LoRACheckpointProbe(CheckpointProbeBase): - def get_format(self) -> str: - return "lycoris" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - token_vector_length = lora_token_vector_length(checkpoint) - - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}") - - -class TextualInversionCheckpointProbe(CheckpointProbeBase): - def get_format(self) -> str: - return None - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - if "string_to_token" in checkpoint: - token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] - elif "emb_params" in checkpoint: - token_dim = checkpoint["emb_params"].shape[-1] - elif "clip_g" in checkpoint: - token_dim = checkpoint["clip_g"].shape[-1] - else: - token_dim = list(checkpoint.values())[0].shape[-1] - if token_dim == 768: - return BaseModelType.StableDiffusion1 - elif token_dim == 1024: - return BaseModelType.StableDiffusion2 - elif token_dim == 1280: - return BaseModelType.StableDiffusionXL - else: - return None - - -class ControlNetCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - for key_name in ( - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - ): - if key_name not in checkpoint: - continue - if checkpoint[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - elif checkpoint[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - elif self.checkpoint_path and self.helper: - return self.helper(self.checkpoint_path) - raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}") - - -class IPAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class 
CLIPVisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class T2IAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -######################################################## -# classes for probing folders -####################################################### -class FolderProbeBase(ProbeBase): - def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used - self.model = model - self.folder_path = folder_path - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - def get_format(self) -> str: - return "diffusers" - - -class PipelineFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self.model: - unet_conf = self.model.unet.config - else: - with open(self.folder_path / "unet" / "config.json", "r") as file: - unet_conf = json.load(file) - if unet_conf["cross_attention_dim"] == 768: - return BaseModelType.StableDiffusion1 - elif unet_conf["cross_attention_dim"] == 1024: - return BaseModelType.StableDiffusion2 - elif unet_conf["cross_attention_dim"] == 1280: - return BaseModelType.StableDiffusionXLRefiner - elif unet_conf["cross_attention_dim"] == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"Unknown base model for {self.folder_path}") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - if self.model: - scheduler_conf = self.model.scheduler.config - else: - with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file: - scheduler_conf = json.load(file) - if scheduler_conf["prediction_type"] == "v_prediction": - return SchedulerPredictionType.VPrediction - elif scheduler_conf["prediction_type"] == "epsilon": - return SchedulerPredictionType.Epsilon - else: - return None - - def get_variant_type(self) -> ModelVariantType: - # This only works for pipelines! Any kind of - # exception results in our returning the - # "normal" variant type - try: - if self.model: - conf = self.model.unet.config - else: - config_file = self.folder_path / "unet" / "config.json" - with open(config_file, "r") as file: - conf = json.load(file) - - in_channels = conf["in_channels"] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - except Exception: - pass - return ModelVariantType.Normal - - -class VaeFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self._config_looks_like_sdxl(): - return BaseModelType.StableDiffusionXL - elif self._name_looks_like_sdxl(): - # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down - # by a factor of 8), we can't necessarily tell them apart by config hyperparameters. - return BaseModelType.StableDiffusionXL - else: - return BaseModelType.StableDiffusion1 - - def _config_looks_like_sdxl(self) -> bool: - # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. 
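        # For reference, the SDXL VAE ships with config values along the lines of
        # {"scaling_factor": 0.13025, "sample_size": 1024, ...}, while the SD 1.x
        # VAEs use a scaling_factor of 0.18215; the test below keys on the SDXL values.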
- config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] - - def _name_looks_like_sdxl(self) -> bool: - return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE)) - - def _guess_name(self) -> str: - name = self.folder_path.name - if name == "vae": - name = self.folder_path.parent.name - return name - - -class TextualInversionFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return None - - def get_base_type(self) -> BaseModelType: - path = self.folder_path / "learned_embeds.bin" - if not path.exists(): - return None - checkpoint = ModelProbe._scan_and_load_checkpoint(path) - return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type() - - -class ONNXFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return "onnx" - - def get_base_type(self) -> BaseModelType: - return BaseModelType.StableDiffusion1 - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - -class ControlNetFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - # no obvious way to distinguish between sd2-base and sd2-768 - dimension = config["cross_attention_dim"] - base_model = ( - BaseModelType.StableDiffusion1 - if dimension == 768 - else ( - BaseModelType.StableDiffusion2 - if dimension == 1024 - else BaseModelType.StableDiffusionXL - if dimension == 2048 - else None - ) - ) - if not base_model: - raise InvalidModelException(f"Unable to determine model base for {self.folder_path}") - return base_model - - -class LoRAFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - model_file = None - for suffix in ["safetensors", "bin"]: - base_file = self.folder_path / f"pytorch_lora_weights.{suffix}" - if base_file.exists(): - model_file = base_file - break - if not model_file: - raise InvalidModelException("Unknown LoRA format encountered") - return LoRACheckpointProbe(model_file, None).get_base_type() - - -class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return IPAdapterModelFormat.InvokeAI.value - - def get_base_type(self) -> BaseModelType: - model_file = self.folder_path / "ip_adapter.bin" - if not model_file.exists(): - raise InvalidModelException("Unknown IP-Adapter model format.") - - state_dict = torch.load(model_file, map_location="cpu") - cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.") - - -class CLIPVisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class T2IAdapterFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise 
InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - - adapter_type = config.get("adapter_type", None) - if adapter_type == "full_adapter_xl": - return BaseModelType.StableDiffusionXL - elif adapter_type == "full_adapter" or "light_adapter": - # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter. - return BaseModelType.StableDiffusion1 - else: - raise InvalidModelException( - f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})." - ) - - -############## register probe classes ###### -ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) - -ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) - -ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/backend/model_management_OLD/model_search.py b/invokeai/backend/model_management_OLD/model_search.py deleted file mode 100644 index e125c3ced7f..00000000000 --- a/invokeai/backend/model_management_OLD/model_search.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2023, Lincoln D. Stein and the InvokeAI Team -""" -Abstract base class for recursive directory search for models. -""" - -import os -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Set, types - -import invokeai.backend.util.logging as logger - - -class ModelSearch(ABC): - def __init__(self, directories: List[Path], logger: types.ModuleType = logger): - """ - Initialize a recursive model directory search. - :param directories: List of directory Paths to recurse through - :param logger: Logger to use - """ - self.directories = directories - self.logger = logger - self._items_scanned = 0 - self._models_found = 0 - self._scanned_dirs = set() - self._scanned_paths = set() - self._pruned_paths = set() - - @abstractmethod - def on_search_started(self): - """ - Called before the scan starts. - """ - pass - - @abstractmethod - def on_model_found(self, model: Path): - """ - Process a found model. Raise an exception if something goes wrong. - :param model: Model to process - could be a directory or checkpoint. - """ - pass - - @abstractmethod - def on_search_completed(self): - """ - Perform some activity when the scan is completed. 
May use instance - variables, items_scanned and models_found - """ - pass - - def search(self): - self.on_search_started() - for dir in self.directories: - self.walk_directory(dir) - self.on_search_completed() - - def walk_directory(self, path: Path): - for root, dirs, files in os.walk(path, followlinks=True): - if str(Path(root).name).startswith("."): - self._pruned_paths.add(root) - if any(Path(root).is_relative_to(x) for x in self._pruned_paths): - continue - - self._items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path in self._scanned_paths or path.parent in self._scanned_dirs: - self._scanned_dirs.add(path) - continue - if any( - (path / x).exists() - for x in { - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "image_encoder.txt", - } - ): - try: - self.on_model_found(path) - self._models_found += 1 - self._scanned_dirs.add(path) - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - for f in files: - path = Path(root) / f - if path.parent in self._scanned_dirs: - continue - if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: - try: - self.on_model_found(path) - self._models_found += 1 - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - -class FindModels(ModelSearch): - def on_search_started(self): - self.models_found: Set[Path] = set() - - def on_model_found(self, model: Path): - self.models_found.add(model) - - def on_search_completed(self): - pass - - def list_models(self) -> List[Path]: - self.search() - return list(self.models_found) diff --git a/invokeai/backend/model_management_OLD/models/__init__.py b/invokeai/backend/model_management_OLD/models/__init__.py deleted file mode 100644 index 5f9b13b96f1..00000000000 --- a/invokeai/backend/model_management_OLD/models/__init__.py +++ /dev/null @@ -1,167 +0,0 @@ -import inspect -from enum import Enum -from typing import Literal, get_origin - -from pydantic import BaseModel, ConfigDict, create_model - -from .base import ( # noqa: F401 - BaseModelType, - DuplicateModelException, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelError, - ModelNotFoundException, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SilenceWarnings, - SubModelType, -) -from .clip_vision import CLIPVisionModel -from .controlnet import ControlNetModel # TODO: -from .ip_adapter import IPAdapterModel -from .lora import LoRAModel -from .sdxl import StableDiffusionXLModel -from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model -from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model -from .t2i_adapter import T2IAdapterModel -from .textual_inversion import TextualInversionModel -from .vae import VaeModel - -MODEL_CLASSES = { - BaseModelType.StableDiffusion1: { - ModelType.ONNX: ONNXStableDiffusion1Model, - ModelType.Main: StableDiffusion1Model, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.StableDiffusion2: { - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.Main: StableDiffusion2Model, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.IPAdapter: IPAdapterModel, - 
ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.StableDiffusionXL: { - ModelType.Main: StableDiffusionXLModel, - ModelType.Vae: VaeModel, - # will not work until support written - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelType.Main: StableDiffusionXLModel, - ModelType.Vae: VaeModel, - # will not work until support written - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.Any: { - ModelType.CLIPVision: CLIPVisionModel, - # The following model types are not expected to be used with BaseModelType.Any. - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.Main: StableDiffusion2Model, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.IPAdapter: IPAdapterModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - # BaseModelType.Kandinsky2_1: { - # ModelType.Main: Kandinsky2_1Model, - # ModelType.MoVQ: MoVQModel, - # ModelType.Lora: LoRAModel, - # ModelType.ControlNet: ControlNetModel, - # ModelType.TextualInversion: TextualInversionModel, - # }, -} - -MODEL_CONFIGS = [] -OPENAPI_MODEL_CONFIGS = [] - - -class OpenAPIModelInfoBase(BaseModel): - model_name: str - base_model: BaseModelType - model_type: ModelType - - model_config = ConfigDict(protected_namespaces=()) - - -for _base_model, models in MODEL_CLASSES.items(): - for model_type, model_class in models.items(): - model_configs = set(model_class._get_configs().values()) - model_configs.discard(None) - MODEL_CONFIGS.extend(model_configs) - - # LS: sort to get the checkpoint configs first, which makes - # for a better template in the Swagger docs - for cfg in sorted(model_configs, key=lambda x: str(x)): - model_name, cfg_name = cfg.__qualname__.split(".")[-2:] - openapi_cfg_name = model_name + cfg_name - if openapi_cfg_name in vars(): - continue - - api_wrapper = create_model( - openapi_cfg_name, - __base__=(cfg, OpenAPIModelInfoBase), - model_type=(Literal[model_type], model_type), # type: ignore - ) - vars()[openapi_cfg_name] = api_wrapper - OPENAPI_MODEL_CONFIGS.append(api_wrapper) - - -def get_model_config_enums(): - enums = [] - - for model_config in MODEL_CONFIGS: - if hasattr(inspect, "get_annotations"): - fields = inspect.get_annotations(model_config) - else: - fields = model_config.__annotations__ - try: - field = fields["model_format"] - except Exception: - raise Exception("format field not found") - - # model_format: None - # model_format: SomeModelFormat - # model_format: Literal[SomeModelFormat.Diffusers] - # model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint] - - if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): - enums.append(field) - - elif get_origin(field) is Literal and all( - isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__ - ): - enums.append(type(field.__args__[0])) - - elif field is None: - pass - - else: - raise 
Exception(f"Unsupported format definition in {model_configs.__qualname__}") - - return enums diff --git a/invokeai/backend/model_management_OLD/models/base.py b/invokeai/backend/model_management_OLD/models/base.py deleted file mode 100644 index 7807cb9a542..00000000000 --- a/invokeai/backend/model_management_OLD/models/base.py +++ /dev/null @@ -1,681 +0,0 @@ -import inspect -import json -import os -import sys -import typing -import warnings -from abc import ABCMeta, abstractmethod -from contextlib import suppress -from enum import Enum -from pathlib import Path -from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union - -import numpy as np -import onnx -import safetensors.torch -import torch -from diffusers import ConfigMixin, DiffusionPipeline -from diffusers import logging as diffusers_logging -from onnx import numpy_helper -from onnxruntime import InferenceSession, SessionOptions, get_available_providers -from picklescan.scanner import scan_file_path -from pydantic import BaseModel, ConfigDict, Field -from transformers import logging as transformers_logging - - -class DuplicateModelException(Exception): - pass - - -class InvalidModelException(Exception): - pass - - -class ModelNotFoundException(Exception): - pass - - -class BaseModelType(str, Enum): - Any = "any" # For models that are not associated with any particular base model. - StableDiffusion1 = "sd-1" - StableDiffusion2 = "sd-2" - StableDiffusionXL = "sdxl" - StableDiffusionXLRefiner = "sdxl-refiner" - # Kandinsky2_1 = "kandinsky-2.1" - - -class ModelType(str, Enum): - ONNX = "onnx" - Main = "main" - Vae = "vae" - Lora = "lora" - ControlNet = "controlnet" # used by model_probe - TextualInversion = "embedding" - IPAdapter = "ip_adapter" - CLIPVision = "clip_vision" - T2IAdapter = "t2i_adapter" - - -class SubModelType(str, Enum): - UNet = "unet" - TextEncoder = "text_encoder" - TextEncoder2 = "text_encoder_2" - Tokenizer = "tokenizer" - Tokenizer2 = "tokenizer_2" - Vae = "vae" - VaeDecoder = "vae_decoder" - VaeEncoder = "vae_encoder" - Scheduler = "scheduler" - SafetyChecker = "safety_checker" - # MoVQ = "movq" - - -class ModelVariantType(str, Enum): - Normal = "normal" - Inpaint = "inpaint" - Depth = "depth" - - -class SchedulerPredictionType(str, Enum): - Epsilon = "epsilon" - VPrediction = "v_prediction" - Sample = "sample" - - -class ModelError(str, Enum): - NotFound = "not_found" - - -def model_config_json_schema_extra(schema: dict[str, Any]) -> None: - if "required" not in schema: - schema["required"] = [] - schema["required"].append("model_type") - - -class ModelConfigBase(BaseModel): - path: str # or Path - description: Optional[str] = Field(None) - model_format: Optional[str] = Field(None) - error: Optional[ModelError] = Field(None) - - model_config = ConfigDict( - use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra - ) - - -class EmptyConfigLoader(ConfigMixin): - @classmethod - def load_config(cls, *args, **kwargs): - cls.config_name = kwargs.pop("config_name") - return super().load_config(*args, **kwargs) - - -T_co = TypeVar("T_co", covariant=True) - - -class classproperty(Generic[T_co]): - def __init__(self, fget: Callable[[Any], T_co]) -> None: - self.fget = fget - - def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co: - return self.fget(owner) - - def __set__(self, instance: Optional[Any], value: Any) -> None: - raise AttributeError("cannot set attribute") - - -class ModelBase(metaclass=ABCMeta): - # model_path: str - # 
base_model: BaseModelType - # model_type: ModelType - - def __init__( - self, - model_path: str, - base_model: BaseModelType, - model_type: ModelType, - ): - self.model_path = model_path - self.base_model = base_model - self.model_type = model_type - - def _hf_definition_to_type(self, subtypes: List[str]) -> Type: - if len(subtypes) < 2: - raise Exception("Invalid subfolder definition!") - if all(t is None for t in subtypes): - return None - elif any(t is None for t in subtypes): - raise Exception(f"Unsupported definition: {subtypes}") - - if subtypes[0] in ["diffusers", "transformers"]: - res_type = sys.modules[subtypes[0]] - subtypes = subtypes[1:] - - else: - res_type = sys.modules["diffusers"] - res_type = res_type.pipelines - - for subtype in subtypes: - res_type = getattr(res_type, subtype) - return res_type - - @classmethod - def _get_configs(cls): - with suppress(Exception): - return cls.__configs - - configs = {} - for name in dir(cls): - if name.startswith("__"): - continue - - value = getattr(cls, name) - if not isinstance(value, type) or not issubclass(value, ModelConfigBase): - continue - - if hasattr(inspect, "get_annotations"): - fields = inspect.get_annotations(value) - else: - fields = value.__annotations__ - try: - field = fields["model_format"] - except Exception: - raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})") - - if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): - for model_format in field: - configs[model_format.value] = value - - elif typing.get_origin(field) is Literal and all( - isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__ - ): - for model_format in field.__args__: - configs[model_format.value] = value - - elif field is None: - configs[None] = value - - else: - raise Exception(f"Unsupported format definition in {cls.__qualname__}") - - cls.__configs = configs - return cls.__configs - - @classmethod - def create_config(cls, **kwargs) -> ModelConfigBase: - if "model_format" not in kwargs: - raise Exception("Field 'model_format' not found in model config") - - configs = cls._get_configs() - return configs[kwargs["model_format"]](**kwargs) - - @classmethod - def probe_config(cls, path: str, **kwargs) -> ModelConfigBase: - return cls.create_config( - path=path, - model_format=cls.detect_format(path), - ) - - @classmethod - @abstractmethod - def detect_format(cls, path: str) -> str: - raise NotImplementedError() - - @classproperty - @abstractmethod - def save_to_config(cls) -> bool: - raise NotImplementedError() - - @abstractmethod - def get_size(self, child_type: Optional[SubModelType] = None) -> int: - raise NotImplementedError() - - @abstractmethod - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ) -> Any: - raise NotImplementedError() - - -class DiffusersModel(ModelBase): - # child_types: Dict[str, Type] - # child_sizes: Dict[str, int] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - super().__init__(model_path, base_model, model_type) - - self.child_types: Dict[str, Type] = {} - self.child_sizes: Dict[str, int] = {} - - try: - config_data = DiffusionPipeline.load_config(self.model_path) - # config_data = json.loads(os.path.join(self.model_path, "model_index.json")) - except Exception: - raise Exception("Invalid diffusers model! 
(model_index.json not found or invalid)") - - config_data.pop("_ignore_files", None) - - # retrieve all folder_names that contain relevant files - child_components = [k for k, v in config_data.items() if isinstance(v, list)] - - for child_name in child_components: - child_type = self._hf_definition_to_type(config_data[child_name]) - self.child_types[child_name] = child_type - self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is None: - return sum(self.child_sizes.values()) - else: - return self.child_sizes[child_type] - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - # return pipeline in different function to pass more arguments - if child_type is None: - raise Exception("Child model type can't be null on diffusers model") - if child_type not in self.child_types: - return None # TODO: or raise - - if torch_dtype == torch.float16: - variants = ["fp16", None] - else: - variants = [None, "fp16"] - - # TODO: better error handling(differentiate not found from others) - for variant in variants: - try: - # TODO: set cache_dir to /dev/null to be sure that cache not used? - model = self.child_types[child_type].from_pretrained( - self.model_path, - subfolder=child_type.value, - torch_dtype=torch_dtype, - variant=variant, - local_files_only=True, - ) - break - except Exception as e: - if not str(e).startswith("Error no file"): - print("====ERR LOAD====") - print(f"{variant}: {e}") - pass - else: - raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model") - - # calc more accurate size - self.child_sizes[child_type] = calc_model_size_by_data(model) - return model - - # def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str: - - -def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None): - if subfolder is not None: - model_path = os.path.join(model_path, subfolder) - - # this can happen when, for example, the safety checker - # is not downloaded. - if not os.path.exists(model_path): - return 0 - - all_files = os.listdir(model_path) - all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))] - - fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f} - bit8_files = {f for f in all_files if ".8bit." 
in f or ".8bit-" in f} - other_files = set(all_files) - fp16_files - bit8_files - - if variant is None: - files = other_files - elif variant == "fp16": - files = fp16_files - elif variant == "8bit": - files = bit8_files - else: - raise NotImplementedError(f"Unknown variant: {variant}") - - # try read from index if exists - index_postfix = ".index.json" - if variant is not None: - index_postfix = f".index.{variant}.json" - - for file in files: - if not file.endswith(index_postfix): - continue - try: - with open(os.path.join(model_path, file), "r") as f: - index_data = json.loads(f.read()) - return int(index_data["metadata"]["total_size"]) - except Exception: - pass - - # calculate files size if there is no index file - formats = [ - (".safetensors",), # safetensors - (".bin",), # torch - (".onnx", ".pb"), # onnx - (".msgpack",), # flax - (".ckpt",), # tf - (".h5",), # tf2 - ] - - for file_format in formats: - model_files = [f for f in files if f.endswith(file_format)] - if len(model_files) == 0: - continue - - model_size = 0 - for model_file in model_files: - file_stats = os.stat(os.path.join(model_path, model_file)) - model_size += file_stats.st_size - return model_size - - # raise NotImplementedError(f"Unknown model structure! Files: {all_files}") - return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu - - -def calc_model_size_by_data(model) -> int: - if isinstance(model, DiffusionPipeline): - return _calc_pipeline_by_data(model) - elif isinstance(model, torch.nn.Module): - return _calc_model_by_data(model) - elif isinstance(model, IAIOnnxRuntimeModel): - return _calc_onnx_model_by_data(model) - else: - return 0 - - -def _calc_pipeline_by_data(pipeline) -> int: - res = 0 - for submodel_key in pipeline.components.keys(): - submodel = getattr(pipeline, submodel_key) - if submodel is not None and isinstance(submodel, torch.nn.Module): - res += _calc_model_by_data(submodel) - return res - - -def _calc_model_by_data(model) -> int: - mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()]) - mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) - mem = mem_params + mem_bufs # in bytes - return mem - - -def _calc_onnx_model_by_data(model) -> int: - tensor_size = model.tensors.size() * 2 # The session doubles this - mem = tensor_size # in bytes - return mem - - -def _fast_safetensors_reader(path: str): - checkpoint = {} - device = torch.device("meta") - with open(path, "rb") as f: - definition_len = int.from_bytes(f.read(8), "little") - definition_json = f.read(definition_len) - definition = json.loads(definition_json) - - if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in { - "pt", - "torch", - "pytorch", - }: - raise Exception("Supported only pytorch safetensors files") - definition.pop("__metadata__", None) - - for key, info in definition.items(): - dtype = { - "I8": torch.int8, - "I16": torch.int16, - "I32": torch.int32, - "I64": torch.int64, - "F16": torch.float16, - "F32": torch.float32, - "F64": torch.float64, - }[info["dtype"]] - - checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device) - - return checkpoint - - -def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): - if str(path).endswith(".safetensors"): - try: - checkpoint = _fast_safetensors_reader(path) - except Exception: - # TODO: create issue for support "meta"? 
- checkpoint = safetensors.torch.load_file(path, device="cpu") - else: - if scan: - scan_result = scan_file_path(path) - if scan_result.infected_files != 0: - raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.') - checkpoint = torch.load(path, map_location=torch.device("meta")) - return checkpoint - - -class SilenceWarnings(object): - def __init__(self): - self.transformers_verbosity = transformers_logging.get_verbosity() - self.diffusers_verbosity = diffusers_logging.get_verbosity() - - def __enter__(self): - transformers_logging.set_verbosity_error() - diffusers_logging.set_verbosity_error() - warnings.simplefilter("ignore") - - def __exit__(self, type, value, traceback): - transformers_logging.set_verbosity(self.transformers_verbosity) - diffusers_logging.set_verbosity(self.diffusers_verbosity) - warnings.simplefilter("default") - - -ONNX_WEIGHTS_NAME = "model.onnx" - - -class IAIOnnxRuntimeModel: - class _tensor_access: - def __init__(self, model): - self.model = model - self.indexes = {} - for idx, obj in enumerate(self.model.proto.graph.initializer): - self.indexes[obj.name] = idx - - def __getitem__(self, key: str): - value = self.model.proto.graph.initializer[self.indexes[key]] - return numpy_helper.to_array(value) - - def __setitem__(self, key: str, value: np.ndarray): - new_node = numpy_helper.from_array(value) - # set_external_data(new_node, location="in-memory-location") - new_node.name = key - # new_node.ClearField("raw_data") - del self.model.proto.graph.initializer[self.indexes[key]] - self.model.proto.graph.initializer.insert(self.indexes[key], new_node) - # self.model.data[key] = OrtValue.ortvalue_from_numpy(value) - - # __delitem__ - - def __contains__(self, key: str): - return self.indexes[key] in self.model.proto.graph.initializer - - def items(self): - raise NotImplementedError("tensor.items") - # return [(obj.name, obj) for obj in self.raw_proto] - - def keys(self): - return self.indexes.keys() - - def values(self): - raise NotImplementedError("tensor.values") - # return [obj for obj in self.raw_proto] - - def size(self): - bytesSum = 0 - for node in self.model.proto.graph.initializer: - bytesSum += sys.getsizeof(node.raw_data) - return bytesSum - - class _access_helper: - def __init__(self, raw_proto): - self.indexes = {} - self.raw_proto = raw_proto - for idx, obj in enumerate(raw_proto): - self.indexes[obj.name] = idx - - def __getitem__(self, key: str): - return self.raw_proto[self.indexes[key]] - - def __setitem__(self, key: str, value): - index = self.indexes[key] - del self.raw_proto[index] - self.raw_proto.insert(index, value) - - # __delitem__ - - def __contains__(self, key: str): - return key in self.indexes - - def items(self): - return [(obj.name, obj) for obj in self.raw_proto] - - def keys(self): - return self.indexes.keys() - - def values(self): - return list(self.raw_proto) - - def __init__(self, model_path: str, provider: Optional[str]): - self.path = model_path - self.session = None - self.provider = provider - """ - self.data_path = self.path + "_data" - if not os.path.exists(self.data_path): - print(f"Moving model tensors to separate file: {self.data_path}") - tmp_proto = onnx.load(model_path, load_external_data=True) - onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False) - del tmp_proto - gc.collect() - - self.proto = onnx.load(model_path, load_external_data=False) - """ - - 
self.proto = onnx.load(model_path, load_external_data=True) - # self.data = dict() - # for tensor in self.proto.graph.initializer: - # name = tensor.name - - # if tensor.HasField("raw_data"): - # npt = numpy_helper.to_array(tensor) - # orv = OrtValue.ortvalue_from_numpy(npt) - # # self.data[name] = orv - # # set_external_data(tensor, location="in-memory-location") - # tensor.name = name - # # tensor.ClearField("raw_data") - - self.nodes = self._access_helper(self.proto.graph.node) - # self.initializers = self._access_helper(self.proto.graph.initializer) - # print(self.proto.graph.input) - # print(self.proto.graph.initializer) - - self.tensors = self._tensor_access(self) - - # TODO: integrate with model manager/cache - def create_session(self, height=None, width=None): - if self.session is None or self.session_width != width or self.session_height != height: - # onnx.save(self.proto, "tmp.onnx") - # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) - # TODO: something to be able to get weight when they already moved outside of model proto - # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto) - sess = SessionOptions() - # self._external_data.update(**external_data) - # sess.add_external_initializers(list(self.data.keys()), list(self.data.values())) - # sess.enable_profiling = True - - # sess.intra_op_num_threads = 1 - # sess.inter_op_num_threads = 1 - # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL - # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - # sess.enable_cpu_mem_arena = True - # sess.enable_mem_pattern = True - # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code - self.session_height = height - self.session_width = width - if height and width: - sess.add_free_dimension_override_by_name("unet_sample_batch", 2) - sess.add_free_dimension_override_by_name("unet_sample_channels", 4) - sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) - sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) - sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height) - sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width) - sess.add_free_dimension_override_by_name("unet_time_batch", 1) - providers = [] - if self.provider: - providers.append(self.provider) - else: - providers = get_available_providers() - if "TensorrtExecutionProvider" in providers: - providers.remove("TensorrtExecutionProvider") - try: - self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess) - except Exception as e: - raise e - # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) - # self.io_binding = self.session.io_binding() - - def release_session(self): - self.session = None - import gc - - gc.collect() - return - - def __call__(self, **kwargs): - if self.session is None: - raise Exception("You should call create_session before running model") - - inputs = {k: np.array(v) for k, v in kwargs.items()} - # output_names = self.session.get_outputs() - # for k in inputs: - # self.io_binding.bind_cpu_input(k, inputs[k]) - # for name in output_names: - # self.io_binding.bind_output(name.name) - # self.session.run_with_iobinding(self.io_binding, None) - # return self.io_binding.copy_outputs_to_cpu() - return self.session.run(None, 
inputs) - - # compatability with diffusers load code - @classmethod - def from_pretrained( - cls, - model_id: Union[str, Path], - subfolder: Union[str, Path] = None, - file_name: Optional[str] = None, - provider: Optional[str] = None, - sess_options: Optional["SessionOptions"] = None, - **kwargs, - ): - file_name = file_name or ONNX_WEIGHTS_NAME - - if os.path.isdir(model_id): - model_path = model_id - if subfolder is not None: - model_path = os.path.join(model_path, subfolder) - model_path = os.path.join(model_path, file_name) - - else: - model_path = model_id - - # load model from local directory - if not os.path.isfile(model_path): - raise Exception(f"Model not found: {model_path}") - - # TODO: session options - return cls(model_path, provider=provider) diff --git a/invokeai/backend/model_management_OLD/models/clip_vision.py b/invokeai/backend/model_management_OLD/models/clip_vision.py deleted file mode 100644 index 2276c6beed1..00000000000 --- a/invokeai/backend/model_management_OLD/models/clip_vision.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -from enum import Enum -from typing import Literal, Optional - -import torch -from transformers import CLIPVisionModelWithProjection - -from invokeai.backend.model_management.models.base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class CLIPVisionModelFormat(str, Enum): - Diffusers = "diffusers" - - -class CLIPVisionModel(ModelBase): - class DiffusersConfig(ModelConfigBase): - model_format: Literal[CLIPVisionModelFormat.Diffusers] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.CLIPVision - super().__init__(model_path, base_model, model_type) - - self.model_size = calc_model_size_by_fs(self.model_path) - - @classmethod - def detect_format(cls, path: str) -> str: - if not os.path.exists(path): - raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.") - - if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")): - return CLIPVisionModelFormat.Diffusers - - raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}") - - @classproperty - def save_to_config(cls) -> bool: - return True - - def get_size(self, child_type: Optional[SubModelType] = None) -> int: - if child_type is not None: - raise ValueError("There are no child models in a CLIP Vision model.") - - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ) -> CLIPVisionModelWithProjection: - if child_type is not None: - raise ValueError("There are no child models in a CLIP Vision model.") - - model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype) - - # Calculate a more accurate model size. 
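# [Editorial aside, not part of the deleted file] calc_model_size_by_data() is used by several
# of the removed model classes to refresh the size estimate once the weights are actually in
# memory. A minimal sketch of that idea, assuming a torch.nn.Module and summing parameter and
# buffer bytes (illustrative helper name, not the library's implementation):
import torch

def estimate_size_bytes(module: torch.nn.Module) -> int:
    total = 0
    for p in module.parameters():
        total += p.nelement() * p.element_size()  # bytes = element count * bytes per element
    for b in module.buffers():
        total += b.nelement() * b.element_size()
    return total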
- self.model_size = calc_model_size_by_data(model) - - return model - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - format = cls.detect_format(model_path) - if format == CLIPVisionModelFormat.Diffusers: - return model_path - else: - raise ValueError(f"Unsupported format: '{format}'.") diff --git a/invokeai/backend/model_management_OLD/models/controlnet.py b/invokeai/backend/model_management_OLD/models/controlnet.py deleted file mode 100644 index 3b534cb9d14..00000000000 --- a/invokeai/backend/model_management_OLD/models/controlnet.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -from enum import Enum -from pathlib import Path -from typing import Literal, Optional - -import torch - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig - -from .base import ( - BaseModelType, - EmptyConfigLoader, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class ControlNetModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class ControlNetModel(ModelBase): - # model_class: Type - # model_size: int - - class DiffusersConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Diffusers] - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Checkpoint] - config: str - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.ControlNet - super().__init__(model_path, base_model, model_type) - - try: - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - # config = json.loads(os.path.join(self.model_path, "config.json")) - except Exception: - raise Exception("Invalid controlnet model! (config.json not found or invalid)") - - model_class_name = config.get("_class_name", None) - if model_class_name not in {"ControlNetModel"}: - raise Exception(f"Invalid ControlNet model! 
Unknown _class_name: {model_class_name}") - - try: - self.model_class = self._hf_definition_to_type(["diffusers", model_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - except Exception: - raise Exception("Invalid ControlNet model!") - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in controlnet model") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There are no child models in controlnet model") - - model = None - for variant in ["fp16", None]: - try: - model = self.model_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - variant=variant, - ) - break - except Exception: - pass - if not model: - raise ModelNotFoundException() - - # calc more accurate size - self.model_size = calc_model_size_by_data(model) - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException() - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "config.json")): - return ControlNetModelFormat.Diffusers - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]): - return ControlNetModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint: - return _convert_controlnet_ckpt_and_cache( - model_path=model_path, - model_config=config.config, - output_path=output_path, - base_model=base_model, - ) - else: - return model_path - - -def _convert_controlnet_ckpt_and_cache( - model_path: str, - output_path: str, - base_model: BaseModelType, - model_config: str, -) -> str: - """ - Convert the controlnet from checkpoint format to diffusers format, - cache it to disk, and return Path to converted - file. If already on disk then just returns Path. 
- """ - app_config = InvokeAIAppConfig.get_config() - weights = app_config.root_path / model_path - output_path = Path(output_path) - - logger.info(f"Converting {weights} to diffusers format") - # return cached version if it exists - if output_path.exists(): - return output_path - - # to avoid circular import errors - from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers - - convert_controlnet_to_diffusers( - weights, - output_path, - original_config_file=app_config.root_path / model_config, - image_size=512, - scan_needed=True, - from_safetensors=weights.suffix == ".safetensors", - ) - return output_path diff --git a/invokeai/backend/model_management_OLD/models/ip_adapter.py b/invokeai/backend/model_management_OLD/models/ip_adapter.py deleted file mode 100644 index c60edd0abe3..00000000000 --- a/invokeai/backend/model_management_OLD/models/ip_adapter.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -import typing -from enum import Enum -from typing import Literal, Optional - -import torch - -from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter -from invokeai.backend.model_management.models.base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelType, - SubModelType, - calc_model_size_by_fs, - classproperty, -) - - -class IPAdapterModelFormat(str, Enum): - # The custom IP-Adapter model format defined by InvokeAI. - InvokeAI = "invokeai" - - -class IPAdapterModel(ModelBase): - class InvokeAIConfig(ModelConfigBase): - model_format: Literal[IPAdapterModelFormat.InvokeAI] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.IPAdapter - super().__init__(model_path, base_model, model_type) - - self.model_size = calc_model_size_by_fs(self.model_path) - - @classmethod - def detect_format(cls, path: str) -> str: - if not os.path.exists(path): - raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.") - - if os.path.isdir(path): - model_file = os.path.join(path, "ip_adapter.bin") - image_encoder_config_file = os.path.join(path, "image_encoder.txt") - if os.path.exists(model_file) and os.path.exists(image_encoder_config_file): - return IPAdapterModelFormat.InvokeAI - - raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}") - - @classproperty - def save_to_config(cls) -> bool: - return True - - def get_size(self, child_type: Optional[SubModelType] = None) -> int: - if child_type is not None: - raise ValueError("There are no child models in an IP-Adapter model.") - - return self.model_size - - def get_model( - self, - torch_dtype: torch.dtype, - child_type: Optional[SubModelType] = None, - ) -> typing.Union[IPAdapter, IPAdapterPlus]: - if child_type is not None: - raise ValueError("There are no child models in an IP-Adapter model.") - - model = build_ip_adapter( - ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), - device=torch.device("cpu"), - dtype=torch_dtype, - ) - - self.model_size = model.calc_size() - return model - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - format = cls.detect_format(model_path) - if format == IPAdapterModelFormat.InvokeAI: - return model_path - else: - raise ValueError(f"Unsupported format: '{format}'.") - - -def get_ip_adapter_image_encoder_model_id(model_path: str): - """Read the ID of the image encoder associated with the IP-Adapter at 
`model_path`.""" - image_encoder_config_file = os.path.join(model_path, "image_encoder.txt") - - with open(image_encoder_config_file, "r") as f: - image_encoder_model = f.readline().strip() - - return image_encoder_model diff --git a/invokeai/backend/model_management_OLD/models/lora.py b/invokeai/backend/model_management_OLD/models/lora.py deleted file mode 100644 index b110d75d220..00000000000 --- a/invokeai/backend/model_management_OLD/models/lora.py +++ /dev/null @@ -1,696 +0,0 @@ -import bisect -import os -from enum import Enum -from pathlib import Path -from typing import Dict, Optional, Union - -import torch -from safetensors.torch import load_file - -from .base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - classproperty, -) - - -class LoRAModelFormat(str, Enum): - LyCORIS = "lycoris" - Diffusers = "diffusers" - - -class LoRAModel(ModelBase): - # model_size: int - - class Config(ModelConfigBase): - model_format: LoRAModelFormat # TODO: - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.Lora - super().__init__(model_path, base_model, model_type) - - self.model_size = os.path.getsize(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in lora") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in lora") - - model = LoRAModelRaw.from_checkpoint( - file_path=self.model_path, - dtype=torch_dtype, - base_model=self.base_model, - ) - - self.model_size = model.calc_size() - return model - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException() - - if os.path.isdir(path): - for ext in ["safetensors", "bin"]: - if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")): - return LoRAModelFormat.Diffusers - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return LoRAModelFormat.LyCORIS - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == LoRAModelFormat.Diffusers: - for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder - path = Path(model_path, f"pytorch_lora_weights.{ext}") - if path.exists(): - return path - else: - return model_path - - -class LoRALayerBase: - # rank: Optional[int] - # alpha: Optional[float] - # bias: Optional[torch.Tensor] - # layer_key: str - - # @property - # def scale(self): - # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - - def __init__( - self, - layer_key: str, - values: dict, - ): - if "alpha" in values: - self.alpha = values["alpha"].item() - else: - self.alpha = None - - if "bias_indices" in values and "bias_values" in values and "bias_size" in values: - self.bias = torch.sparse_coo_tensor( - values["bias_indices"], - values["bias_values"], - tuple(values["bias_size"]), - ) - - else: - self.bias = None - - self.rank = None # set in layer implementation - self.layer_key = layer_key - - 
def get_weight(self, orig_weight: torch.Tensor): - raise NotImplementedError() - - def calc_size(self) -> int: - model_size = 0 - for val in [self.bias]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - if self.bias is not None: - self.bias = self.bias.to(device=device, dtype=dtype) - - -# TODO: find and debug lora/locon with bias -class LoRALayer(LoRALayerBase): - # up: torch.Tensor - # mid: Optional[torch.Tensor] - # down: torch.Tensor - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.up = values["lora_up.weight"] - self.down = values["lora_down.weight"] - if "lora_mid.weight" in values: - self.mid = values["lora_mid.weight"] - else: - self.mid = None - - self.rank = self.down.shape[0] - - def get_weight(self, orig_weight: torch.Tensor): - if self.mid is not None: - up = self.up.reshape(self.up.shape[0], self.up.shape[1]) - down = self.down.reshape(self.down.shape[0], self.down.shape[1]) - weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) - else: - weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.up, self.mid, self.down]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.up = self.up.to(device=device, dtype=dtype) - self.down = self.down.to(device=device, dtype=dtype) - - if self.mid is not None: - self.mid = self.mid.to(device=device, dtype=dtype) - - -class LoHALayer(LoRALayerBase): - # w1_a: torch.Tensor - # w1_b: torch.Tensor - # w2_a: torch.Tensor - # w2_b: torch.Tensor - # t1: Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.w1_a = values["hada_w1_a"] - self.w1_b = values["hada_w1_b"] - self.w2_a = values["hada_w2_a"] - self.w2_b = values["hada_w2_b"] - - if "hada_t1" in values: - self.t1 = values["hada_t1"] - else: - self.t1 = None - - if "hada_t2" in values: - self.t2 = values["hada_t2"] - else: - self.t2 = None - - self.rank = self.w1_b.shape[0] - - def get_weight(self, orig_weight: torch.Tensor): - if self.t1 is None: - weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) - - else: - rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) - rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) - weight = rebuild1 * rebuild2 - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - if self.t1 is not None: - self.t1 = self.t1.to(device=device, dtype=dtype) - - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, 
dtype=dtype) - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class LoKRLayer(LoRALayerBase): - # w1: Optional[torch.Tensor] = None - # w1_a: Optional[torch.Tensor] = None - # w1_b: Optional[torch.Tensor] = None - # w2: Optional[torch.Tensor] = None - # w2_a: Optional[torch.Tensor] = None - # w2_b: Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - if "lokr_w1" in values: - self.w1 = values["lokr_w1"] - self.w1_a = None - self.w1_b = None - else: - self.w1 = None - self.w1_a = values["lokr_w1_a"] - self.w1_b = values["lokr_w1_b"] - - if "lokr_w2" in values: - self.w2 = values["lokr_w2"] - self.w2_a = None - self.w2_b = None - else: - self.w2 = None - self.w2_a = values["lokr_w2_a"] - self.w2_b = values["lokr_w2_b"] - - if "lokr_t2" in values: - self.t2 = values["lokr_t2"] - else: - self.t2 = None - - if "lokr_w1_b" in values: - self.rank = values["lokr_w1_b"].shape[0] - elif "lokr_w2_b" in values: - self.rank = values["lokr_w2_b"].shape[0] - else: - self.rank = None # unscaled - - def get_weight(self, orig_weight: torch.Tensor): - w1 = self.w1 - if w1 is None: - w1 = self.w1_a @ self.w1_b - - w2 = self.w2 - if w2 is None: - if self.t2 is None: - w2 = self.w2_a @ self.w2_b - else: - w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - w2 = w2.contiguous() - weight = torch.kron(w1, w2) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - if self.w1 is not None: - self.w1 = self.w1.to(device=device, dtype=dtype) - else: - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - - if self.w2 is not None: - self.w2 = self.w2.to(device=device, dtype=dtype) - else: - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) - - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class FullLayer(LoRALayerBase): - # weight: torch.Tensor - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.weight = values["diff"] - - if len(values.keys()) > 1: - _keys = list(values.keys()) - _keys.remove("diff") - raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") - - self.rank = None # unscaled - - def get_weight(self, orig_weight: torch.Tensor): - return self.weight - - def calc_size(self) -> int: - model_size = super().calc_size() - model_size += self.weight.nelement() * self.weight.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.weight = self.weight.to(device=device, dtype=dtype) - - -class IA3Layer(LoRALayerBase): - # weight: torch.Tensor - # on_input: torch.Tensor - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.weight = values["weight"] - self.on_input = values["on_input"] - - self.rank = None # 
unscaled - - def get_weight(self, orig_weight: torch.Tensor): - weight = self.weight - if not self.on_input: - weight = weight.reshape(-1, 1) - return orig_weight * weight - - def calc_size(self) -> int: - model_size = super().calc_size() - model_size += self.weight.nelement() * self.weight.element_size() - model_size += self.on_input.nelement() * self.on_input.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.weight = self.weight.to(device=device, dtype=dtype) - self.on_input = self.on_input.to(device=device, dtype=dtype) - - -# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix -class LoRAModelRaw: # (torch.nn.Module): - _name: str - layers: Dict[str, LoRALayer] - - def __init__( - self, - name: str, - layers: Dict[str, LoRALayer], - ): - self._name = name - self.layers = layers - - @property - def name(self): - return self._name - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - # TODO: try revert if exception? - for _key, layer in self.layers.items(): - layer.to(device=device, dtype=dtype) - - def calc_size(self) -> int: - model_size = 0 - for _, layer in self.layers.items(): - model_size += layer.calc_size() - return model_size - - @classmethod - def _convert_sdxl_keys_to_diffusers_format(cls, state_dict): - """Convert the keys of an SDXL LoRA state_dict to diffusers format. - - The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in - diffusers format, then this function will have no effect. - - This function is adapted from: - https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409 - - Args: - state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict. - - Raises: - ValueError: If state_dict contains an unrecognized key, or not all keys could be converted. - - Returns: - Dict[str, Tensor]: The diffusers-format state_dict. - """ - converted_count = 0 # The number of Stability AI keys converted to diffusers format. - not_converted_count = 0 # The number of keys that were not converted. - - # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes. - # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for - # `input_blocks_4_1_proj_in`. - stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) - stability_unet_keys.sort() - - new_state_dict = {} - for full_key, value in state_dict.items(): - if full_key.startswith("lora_unet_"): - search_key = full_key.replace("lora_unet_", "") - # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix. - position = bisect.bisect_right(stability_unet_keys, search_key) - map_key = stability_unet_keys[position - 1] - # Now, check if the map_key *actually* matches the search_key. - if search_key.startswith(map_key): - new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key]) - new_state_dict[new_key] = value - converted_count += 1 - else: - new_state_dict[full_key] = value - not_converted_count += 1 - elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): - # The CLIP text encoders have the same keys in both Stability AI and diffusers formats. 
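# [Editorial aside, not part of the deleted file] The UNet branch above uses bisect_right to
# pick the single Stability-AI key that might be a prefix of the search key, then verifies it
# with startswith. A self-contained sketch of that lookup with one illustrative mapping entry:
import bisect

stability_to_diffusers = {"input_blocks_4_1": "down_blocks_1_attentions_0"}  # illustrative entry
sorted_keys = sorted(stability_to_diffusers)
search_key = "input_blocks_4_1_proj_in"
candidate = sorted_keys[bisect.bisect_right(sorted_keys, search_key) - 1]
if search_key.startswith(candidate):
    converted = search_key.replace(candidate, stability_to_diffusers[candidate])
    # converted == "down_blocks_1_attentions_0_proj_in"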
- new_state_dict[full_key] = value - continue - else: - raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.") - - if converted_count > 0 and not_converted_count > 0: - raise ValueError( - f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count}," - f" not_converted={not_converted_count}" - ) - - return new_state_dict - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - base_model: Optional[BaseModelType] = None, - ): - device = device or torch.device("cpu") - dtype = dtype or torch.float32 - - if isinstance(file_path, str): - file_path = Path(file_path) - - model = cls( - name=file_path.stem, # TODO: - layers={}, - ) - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - state_dict = cls._group_state(state_dict) - - if base_model == BaseModelType.StableDiffusionXL: - state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) - - for layer_key, values in state_dict.items(): - # lora and locon - if "lora_down.weight" in values: - layer = LoRALayer(layer_key, values) - - # loha - elif "hada_w1_b" in values: - layer = LoHALayer(layer_key, values) - - # lokr - elif "lokr_w1_b" in values or "lokr_w1" in values: - layer = LoKRLayer(layer_key, values) - - # diff - elif "diff" in values: - layer = FullLayer(layer_key, values) - - # ia3 - elif "weight" in values and "on_input" in values: - layer = IA3Layer(layer_key, values) - - else: - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") - raise Exception("Unknown lora format!") - - # lower memory consumption by removing already parsed layer values - state_dict[layer_key].clear() - - layer.to(device=device, dtype=dtype) - model.layers[layer_key] = layer - - return model - - @staticmethod - def _group_state(state_dict: dict): - state_dict_groupped = {} - - for key, value in state_dict.items(): - stem, leaf = key.split(".", 1) - if stem not in state_dict_groupped: - state_dict_groupped[stem] = {} - state_dict_groupped[stem][leaf] = value - - return state_dict_groupped - - -# code from -# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 -def make_sdxl_unet_conversion_map(): - """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" - unet_conversion_map_layer = [] - - for i in range(3): # num_blocks is 3 in sdxl - # loop over downblocks/upblocks - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - # if i > 0: commentout for sdxl - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0.", "norm1."), - ("in_layers.2.", "conv1."), - ("out_layers.0.", "norm2."), - ("out_layers.3.", "conv2."), - ("emb_layers.1.", "time_emb_proj."), - ("skip_connection.", "conv_shortcut."), - ] - - unet_conversion_map = [] - for sd, hf in unet_conversion_map_layer: - if "resnets" in hf: - for sd_res, hf_res in unet_conversion_map_resnet: - unet_conversion_map.append((sd + sd_res, hf + hf_res)) - else: - unet_conversion_map.append((sd, hf)) - - for j in range(2): - hf_time_embed_prefix = f"time_embedding.linear_{j+1}." - sd_time_embed_prefix = f"time_embed.{j*2}." - unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) - - for j in range(2): - hf_label_embed_prefix = f"add_embedding.linear_{j+1}." - sd_label_embed_prefix = f"label_emb.0.{j*2}." 
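# [Editorial aside, not part of the deleted file] One concrete pair produced by the composition
# loop above (down block i=0, resnet j=0, combined with the "in_layers.0." -> "norm1." resnet rule):
sd_example = "input_blocks.1.0." + "in_layers.0."   # Stability AI prefix + resnet sub-key
hf_example = "down_blocks.0.resnets.0." + "norm1."  # diffusers prefix + resnet sub-key
assert (sd_example, hf_example) == ("input_blocks.1.0.in_layers.0.", "down_blocks.0.resnets.0.norm1.")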
- unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) - - unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) - unet_conversion_map.append(("out.0.", "conv_norm_out.")) - unet_conversion_map.append(("out.2.", "conv_out.")) - - return unet_conversion_map - - -SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = { - sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() -} diff --git a/invokeai/backend/model_management_OLD/models/sdxl.py b/invokeai/backend/model_management_OLD/models/sdxl.py deleted file mode 100644 index 01e9420fed7..00000000000 --- a/invokeai/backend/model_management_OLD/models/sdxl.py +++ /dev/null @@ -1,148 +0,0 @@ -import json -import os -from enum import Enum -from pathlib import Path -from typing import Literal, Optional - -from omegaconf import OmegaConf -from pydantic import Field - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae -from invokeai.backend.util.logging import InvokeAILogger - -from .base import ( - BaseModelType, - DiffusersModel, - InvalidModelException, - ModelConfigBase, - ModelType, - ModelVariantType, - classproperty, - read_checkpoint_meta, -) - - -class StableDiffusionXLModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class StableDiffusionXLModel(DiffusersModel): - # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusionXLModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusionXLModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner} - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusionXL, - model_type=ModelType.Main, - ) - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusionXLModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = OmegaConf.load(ckpt_config_path) - in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get("state_dict", checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusionXLModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config["in_channels"] - - else: - raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model") - - else: - raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if ckpt_config_path is None: - # avoid circular 
import - from .stable_diffusion import _select_ckpt_config - - ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant) - - return cls.create_config( - path=path, - model_format=model_format, - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if os.path.isdir(model_path): - return StableDiffusionXLModelFormat.Diffusers - else: - return StableDiffusionXLModelFormat.Checkpoint - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - # The convert script adapted from the diffusers package uses - # strings for the base model type. To avoid making too many - # source code changes, we simply translate here - if Path(output_path).exists(): - return output_path - - if isinstance(config, cls.CheckpointConfig): - from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache - - # Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed, - # then we bake it into the converted model unless there is already - # a nonstandard VAE installed. - kwargs = {} - app_config = InvokeAIAppConfig.get_config() - vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix" - if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)): - InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.") - kwargs["vae_path"] = vae_path - - return _convert_ckpt_and_cache( - version=base_model, - model_config=config, - output_path=output_path, - use_safetensors=True, - **kwargs, - ) - else: - return model_path diff --git a/invokeai/backend/model_management_OLD/models/stable_diffusion.py b/invokeai/backend/model_management_OLD/models/stable_diffusion.py deleted file mode 100644 index a38a44fccf7..00000000000 --- a/invokeai/backend/model_management_OLD/models/stable_diffusion.py +++ /dev/null @@ -1,337 +0,0 @@ -import json -import os -from enum import Enum -from pathlib import Path -from typing import Literal, Optional, Union - -from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline -from omegaconf import OmegaConf -from pydantic import Field - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig - -from .base import ( - BaseModelType, - DiffusersModel, - InvalidModelException, - ModelConfigBase, - ModelNotFoundException, - ModelType, - ModelVariantType, - SilenceWarnings, - classproperty, - read_checkpoint_meta, -) -from .sdxl import StableDiffusionXLModel - - -class StableDiffusion1ModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class StableDiffusion1Model(DiffusersModel): - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusion1ModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusion1ModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion1 - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion1, - model_type=ModelType.Main, - ) - - @classmethod - def 
probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusion1ModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = OmegaConf.load(ckpt_config_path) - ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get("state_dict", checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusion1ModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config["in_channels"] - - else: - raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format") - - else: - raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 1.* model format") - - if ckpt_config_path is None: - ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant) - - return cls.create_config( - path=path, - model_format=model_format, - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if not os.path.exists(model_path): - raise ModelNotFoundException() - - if os.path.isdir(model_path): - if os.path.exists(os.path.join(model_path, "model_index.json")): - return StableDiffusion1ModelFormat.Diffusers - - if os.path.isfile(model_path): - if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return StableDiffusion1ModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {model_path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if isinstance(config, cls.CheckpointConfig): - return _convert_ckpt_and_cache( - version=BaseModelType.StableDiffusion1, - model_config=config, - load_safety_checker=False, - output_path=output_path, - ) - else: - return model_path - - -class StableDiffusion2ModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class StableDiffusion2Model(DiffusersModel): - # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusion2ModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusion2ModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion2 - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion2, - model_type=ModelType.Main, - ) - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusion2ModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = 
OmegaConf.load(ckpt_config_path) - ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get("state_dict", checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusion2ModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config["in_channels"] - - else: - raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") - - else: - raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if ckpt_config_path is None: - ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant) - - return cls.create_config( - path=path, - model_format=model_format, - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if not os.path.exists(model_path): - raise ModelNotFoundException() - - if os.path.isdir(model_path): - if os.path.exists(os.path.join(model_path, "model_index.json")): - return StableDiffusion2ModelFormat.Diffusers - - if os.path.isfile(model_path): - if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return StableDiffusion2ModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {model_path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if isinstance(config, cls.CheckpointConfig): - return _convert_ckpt_and_cache( - version=BaseModelType.StableDiffusion2, - model_config=config, - output_path=output_path, - ) - else: - return model_path - - -# TODO: rework -# pass precision - currently defaulting to fp16 -def _convert_ckpt_and_cache( - version: BaseModelType, - model_config: Union[ - StableDiffusion1Model.CheckpointConfig, - StableDiffusion2Model.CheckpointConfig, - StableDiffusionXLModel.CheckpointConfig, - ], - output_path: str, - use_save_model: bool = False, - **kwargs, -) -> str: - """ - Convert the checkpoint model indicated in mconfig into a - diffusers, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. 
- """ - app_config = InvokeAIAppConfig.get_config() - - weights = app_config.models_path / model_config.path - config_file = app_config.root_path / model_config.config - output_path = Path(output_path) - variant = model_config.variant - pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline - - # return cached version if it exists - if output_path.exists(): - return output_path - - # to avoid circular import errors - from ...util.devices import choose_torch_device, torch_dtype - from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers - - model_base_to_model_type = { - BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder", - BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder", - BaseModelType.StableDiffusionXL: "SDXL", - BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner", - } - logger.info(f"Converting {weights} to diffusers format") - with SilenceWarnings(): - convert_ckpt_to_diffusers( - weights, - output_path, - model_type=model_base_to_model_type[version], - model_version=version, - model_variant=model_config.variant, - original_config_file=config_file, - extract_ema=True, - scan_needed=True, - pipeline_class=pipeline_class, - from_safetensors=weights.suffix == ".safetensors", - precision=torch_dtype(choose_torch_device()), - **kwargs, - ) - return output_path - - -def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): - ckpt_configs = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: "v1-inference.yaml", - ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512) - ModelVariantType.Inpaint: "v2-inpainting-inference.yaml", - ModelVariantType.Depth: "v2-midas-inference.yaml", - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - ModelVariantType.Inpaint: None, - ModelVariantType.Depth: None, - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - ModelVariantType.Inpaint: None, - ModelVariantType.Depth: None, - }, - } - - app_config = InvokeAIAppConfig.get_config() - try: - config_path = app_config.legacy_conf_path / ckpt_configs[version][variant] - if config_path.is_relative_to(app_config.root_path): - config_path = config_path.relative_to(app_config.root_path) - return str(config_path) - - except Exception: - return None diff --git a/invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py b/invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py deleted file mode 100644 index 2d0dd22c43a..00000000000 --- a/invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py +++ /dev/null @@ -1,150 +0,0 @@ -from enum import Enum -from typing import Literal - -from diffusers import OnnxRuntimeModel - -from .base import ( - BaseModelType, - DiffusersModel, - IAIOnnxRuntimeModel, - ModelConfigBase, - ModelType, - ModelVariantType, - SchedulerPredictionType, - classproperty, -) - - -class StableDiffusionOnnxModelFormat(str, Enum): - Olive = "olive" - Onnx = "onnx" - - -class ONNXStableDiffusion1Model(DiffusersModel): - class Config(ModelConfigBase): - model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion1 - assert model_type == ModelType.ONNX - super().__init__( - 
model_path=model_path, - base_model=BaseModelType.StableDiffusion1, - model_type=ModelType.ONNX, - ) - - for child_name, child_type in self.child_types.items(): - if child_type is OnnxRuntimeModel: - self.child_types[child_name] = IAIOnnxRuntimeModel - - # TODO: check that no optimum models provided - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - in_channels = 4 # TODO: - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 1.* model format") - - return cls.create_config( - path=path, - model_format=model_format, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - # TODO: Detect onnx vs olive - return StableDiffusionOnnxModelFormat.Onnx - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - return model_path - - -class ONNXStableDiffusion2Model(DiffusersModel): - # TODO: check that configs overwriten properly - class Config(ModelConfigBase): - model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion2 - assert model_type == ModelType.ONNX - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion2, - model_type=ModelType.ONNX, - ) - - for child_name, child_type in self.child_types.items(): - if child_type is OnnxRuntimeModel: - self.child_types[child_name] = IAIOnnxRuntimeModel - # TODO: check that no optimum models provided - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - in_channels = 4 # TODO: - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if variant == ModelVariantType.Normal: - prediction_type = SchedulerPredictionType.VPrediction - upcast_attention = True - - else: - prediction_type = SchedulerPredictionType.Epsilon - upcast_attention = False - - return cls.create_config( - path=path, - model_format=model_format, - variant=variant, - prediction_type=prediction_type, - upcast_attention=upcast_attention, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - # TODO: Detect onnx vs olive - return StableDiffusionOnnxModelFormat.Onnx - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - return model_path diff --git a/invokeai/backend/model_management_OLD/models/t2i_adapter.py b/invokeai/backend/model_management_OLD/models/t2i_adapter.py deleted file mode 100644 index 4adb9901f99..00000000000 --- a/invokeai/backend/model_management_OLD/models/t2i_adapter.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -from enum import Enum -from typing import Literal, Optional - -import torch -from diffusers import T2IAdapter - -from invokeai.backend.model_management.models.base import ( - BaseModelType, - 
EmptyConfigLoader, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class T2IAdapterModelFormat(str, Enum): - Diffusers = "diffusers" - - -class T2IAdapterModel(ModelBase): - class DiffusersConfig(ModelConfigBase): - model_format: Literal[T2IAdapterModelFormat.Diffusers] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.T2IAdapter - super().__init__(model_path, base_model, model_type) - - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - - model_class_name = config.get("_class_name", None) - if model_class_name not in {"T2IAdapter"}: - raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.") - - self.model_class = self._hf_definition_to_type(["diffusers", model_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ) -> T2IAdapter: - if child_type is not None: - raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.") - - model = None - for variant in ["fp16", None]: - try: - model = self.model_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - variant=variant, - ) - break - except Exception: - pass - if not model: - raise ModelNotFoundException() - - # Calculate a more accurate size after loading the model into memory. 
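# [Editorial aside, not part of the deleted file] The loading loop above tries the "fp16"
# weights variant first and silently falls back to the default weights. A standalone sketch of
# that pattern (the model path here is a placeholder, not a path from the repository):
from diffusers import T2IAdapter

adapter = None
for variant in ("fp16", None):
    try:
        adapter = T2IAdapter.from_pretrained("/path/to/t2i-adapter", variant=variant)
        break
    except Exception:
        continue  # this variant is not available; try the next one
if adapter is None:
    raise RuntimeError("Could not load the adapter with any known weights variant.")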
- self.model_size = calc_model_size_by_data(model) - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException(f"Model not found at '{path}'.") - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "config.json")): - return T2IAdapterModelFormat.Diffusers - - raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - format = cls.detect_format(model_path) - if format == T2IAdapterModelFormat.Diffusers: - return model_path - else: - raise ValueError(f"Unsupported format: '{format}'.") diff --git a/invokeai/backend/model_management_OLD/models/textual_inversion.py b/invokeai/backend/model_management_OLD/models/textual_inversion.py deleted file mode 100644 index 99358704b8d..00000000000 --- a/invokeai/backend/model_management_OLD/models/textual_inversion.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -from typing import Optional - -import torch - -# TODO: naming -from ..lora import TextualInversionModel as TextualInversionModelRaw -from .base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - classproperty, -) - - -class TextualInversionModel(ModelBase): - # model_size: int - - class Config(ModelConfigBase): - model_format: None - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.TextualInversion - super().__init__(model_path, base_model, model_type) - - self.model_size = os.path.getsize(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - - checkpoint_path = self.model_path - if os.path.isdir(checkpoint_path): - checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin") - - if not os.path.exists(checkpoint_path): - raise ModelNotFoundException() - - model = TextualInversionModelRaw.from_checkpoint( - file_path=checkpoint_path, - dtype=torch_dtype, - ) - - self.model_size = model.embedding.nelement() * model.embedding.element_size() - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException() - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "learned_embeds.bin")): - return None # diffusers-ti - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]): - return None - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - return model_path diff --git a/invokeai/backend/model_management_OLD/models/vae.py b/invokeai/backend/model_management_OLD/models/vae.py deleted file mode 100644 index 8cc37e67a73..00000000000 --- a/invokeai/backend/model_management_OLD/models/vae.py +++ 
/dev/null @@ -1,179 +0,0 @@ -import os -from enum import Enum -from pathlib import Path -from typing import Optional - -import safetensors -import torch -from omegaconf import OmegaConf - -from invokeai.app.services.config import InvokeAIAppConfig - -from .base import ( - BaseModelType, - EmptyConfigLoader, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - ModelVariantType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class VaeModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class VaeModel(ModelBase): - # vae_class: Type - # model_size: int - - class Config(ModelConfigBase): - model_format: VaeModelFormat - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.Vae - super().__init__(model_path, base_model, model_type) - - try: - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - # config = json.loads(os.path.join(self.model_path, "config.json")) - except Exception: - raise Exception("Invalid vae model! (config.json not found or invalid)") - - try: - vae_class_name = config.get("_class_name", "AutoencoderKL") - self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - except Exception: - raise Exception("Invalid vae model! (Unkown vae type)") - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in vae model") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in vae model") - - model = self.vae_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - ) - # calc more accurate size - self.model_size = calc_model_size_by_data(model) - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException(f"Does not exist as local file: {path}") - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "config.json")): - return VaeModelFormat.Diffusers - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return VaeModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, # empty config or config of parent model - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == VaeModelFormat.Checkpoint: - return _convert_vae_ckpt_and_cache( - weights_path=model_path, - output_path=output_path, - base_model=base_model, - model_config=config, - ) - else: - return model_path - - -# TODO: rework -def _convert_vae_ckpt_and_cache( - weights_path: str, - output_path: str, - base_model: BaseModelType, - model_config: ModelConfigBase, -) -> str: - """ - Convert the VAE indicated in mconfig into a diffusers AutoencoderKL - object, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. 
- """ - app_config = InvokeAIAppConfig.get_config() - weights_path = app_config.root_dir / weights_path - output_path = Path(output_path) - - """ - this size used only in when tiling enabled to separate input in tiles - sizes in configs from stable diffusion githubs(1 and 2) set to 256 - on huggingface it: - 1.5 - 512 - 1.5-inpainting - 256 - 2-inpainting - 512 - 2-depth - 256 - 2-base - 512 - 2 - 768 - 2.1-base - 768 - 2.1 - 768 - """ - image_size = 512 - - # return cached version if it exists - if output_path.exists(): - return output_path - - if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: - from .stable_diffusion import _select_ckpt_config - - # all sd models use same vae settings - config_file = _select_ckpt_config(base_model, ModelVariantType.Normal) - else: - raise Exception(f"Vae conversion not supported for model type: {base_model}") - - # this avoids circular import error - from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers - - if weights_path.suffix == ".safetensors": - checkpoint = safetensors.torch.load_file(weights_path, device="cpu") - else: - checkpoint = torch.load(weights_path, map_location="cpu") - - # sometimes weights are hidden under "state_dict", and sometimes not - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - config = OmegaConf.load(app_config.root_path / config_file) - - vae_model = convert_ldm_vae_to_diffusers( - checkpoint=checkpoint, - vae_config=config, - image_size=image_size, - ) - vae_model.save_pretrained(output_path, safe_serialization=True) - return output_path diff --git a/invokeai/backend/model_management_OLD/seamless.py b/invokeai/backend/model_management_OLD/seamless.py deleted file mode 100644 index fb9112b56dc..00000000000 --- a/invokeai/backend/model_management_OLD/seamless.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -from contextlib import contextmanager -from typing import Callable, List, Union - -import torch.nn as nn -from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel - - -def _conv_forward_asymmetric(self, input, weight, bias): - """ - Patch for Conv2d._conv_forward that supports asymmetric padding - """ - working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) - working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) - return nn.functional.conv2d( - working, - weight, - bias, - self.stride, - nn.modules.utils._pair(0), - self.dilation, - self.groups, - ) - - -@contextmanager -def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): - # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor - to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = [] - try: - # Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence - skipped_layers = 1 - for m_name, m in model.named_modules(): - if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - continue - - if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." 
in m_name: - # down_blocks.1.resnets.1.conv1 - _, block_num, _, resnet_num, submodule_name = m_name.split(".") - block_num = int(block_num) - resnet_num = int(resnet_num) - - if block_num >= len(model.down_blocks) - skipped_layers: - continue - - # Skip the second resnet (could be configurable) - if resnet_num > 0: - continue - - # Skip Conv2d layers (could be configurable) - if submodule_name == "conv2": - continue - - m.asymmetric_padding_mode = {} - m.asymmetric_padding = {} - m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" - m.asymmetric_padding["x"] = ( - m._reversed_padding_repeated_twice[0], - m._reversed_padding_repeated_twice[1], - 0, - 0, - ) - m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" - m.asymmetric_padding["y"] = ( - 0, - 0, - m._reversed_padding_repeated_twice[2], - m._reversed_padding_repeated_twice[3], - ) - - to_restore.append((m, m._conv_forward)) - m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) - - yield - - finally: - for module, orig_conv_forward in to_restore: - module._conv_forward = orig_conv_forward - if hasattr(module, "asymmetric_padding_mode"): - del module.asymmetric_padding_mode - if hasattr(module, "asymmetric_padding"): - del module.asymmetric_padding diff --git a/invokeai/backend/model_management_OLD/util.py b/invokeai/backend/model_management_OLD/util.py deleted file mode 100644 index f4737d9f0b5..00000000000 --- a/invokeai/backend/model_management_OLD/util.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2023 The InvokeAI Development Team -"""Utilities used by the Model Manager""" - - -def lora_token_vector_length(checkpoint: dict) -> int: - """ - Given a checkpoint in memory, return the lora token vector length - - :param checkpoint: The checkpoint - """ - - def _get_shape_1(key: str, tensor, checkpoint) -> int: - lora_token_vector_length = None - - if "." not in key: - return lora_token_vector_length # wrong key format - model_key, lora_key = key.split(".", 1) - - # check lora/locon - if lora_key == "lora_down.weight": - lora_token_vector_length = tensor.shape[1] - - # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes) - elif lora_key in ["hada_w1_b", "hada_w2_b"]: - lora_token_vector_length = tensor.shape[1] - - # check lokr (don't worry about lokr_t2 as it used only in 4d shapes) - elif "lokr_" in lora_key: - if model_key + ".lokr_w1" in checkpoint: - _lokr_w1 = checkpoint[model_key + ".lokr_w1"] - elif model_key + "lokr_w1_b" in checkpoint: - _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"] - else: - return lora_token_vector_length # unknown format - - if model_key + ".lokr_w2" in checkpoint: - _lokr_w2 = checkpoint[model_key + ".lokr_w2"] - elif model_key + "lokr_w2_b" in checkpoint: - _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"] - else: - return lora_token_vector_length # unknown format - - lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1] - - elif lora_key == "diff": - lora_token_vector_length = tensor.shape[1] - - # ia3 can be detected only by shape[0] in text encoder - elif lora_key == "weight" and "lora_unet_" not in model_key: - lora_token_vector_length = tensor.shape[0] - - return lora_token_vector_length - - lora_token_vector_length = None - lora_te1_length = None - lora_te2_length = None - for key, tensor in checkpoint.items(): - if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." 
in key): - lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) - elif key.startswith("lora_unet_") and ( - "time_emb_proj.lora_down" in key - ): # recognizes format at https://civitai.com/models/224641 - lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) - elif key.startswith("lora_te") and "_self_attn_" in key: - tmp_length = _get_shape_1(key, tensor, checkpoint) - if key.startswith("lora_te_"): - lora_token_vector_length = tmp_length - elif key.startswith("lora_te1_"): - lora_te1_length = tmp_length - elif key.startswith("lora_te2_"): - lora_te2_length = tmp_length - - if lora_te1_length is not None and lora_te2_length is not None: - lora_token_vector_length = lora_te1_length + lora_te2_length - - if lora_token_vector_length is not None: - break - - return lora_token_vector_length diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 98cc5054c73..88356d04686 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -1,5 +1,4 @@ """Re-export frequently-used symbols from the Model Manager backend.""" - from .config import ( AnyModel, AnyModelConfig, @@ -33,3 +32,42 @@ "SchedulerPredictionType", "SubModelType", ] + +########## to help populate the openapi_schema with format enums for each config ########### +# This code is no longer necessary? +# leave it here just in case +# +# import inspect +# from enum import Enum +# from typing import Any, Iterable, Dict, get_args, Set +# def _expand(something: Any) -> Iterable[type]: +# if isinstance(something, type): +# yield something +# else: +# for x in get_args(something): +# for y in _expand(x): +# yield y + +# def _find_format(cls: type) -> Iterable[Enum]: +# if hasattr(inspect, "get_annotations"): +# fields = inspect.get_annotations(cls) +# else: +# fields = cls.__annotations__ +# if "format" in fields: +# for x in get_args(fields["format"]): +# yield x +# for parent_class in cls.__bases__: +# for x in _find_format(parent_class): +# yield x +# return None + +# def get_model_config_formats() -> Dict[str, Set[Enum]]: +# result: Dict[str, Set[Enum]] = {} +# for model_config in _expand(AnyModelConfig): +# for field in _find_format(model_config): +# if field is None: +# continue +# if not result.get(model_config.__qualname__): +# result[model_config.__qualname__] = set() +# result[model_config.__qualname__].add(field) +# return result diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index a3a840b6259..a0421017db9 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -6,12 +6,22 @@ from pathlib import Path from .convert_cache.convert_cache_default import ModelConvertCache -from .load_base import AnyModelLoader, LoadedModel +from .load_base import LoadedModel, ModelLoaderBase +from .load_default import ModelLoader from .model_cache.model_cache_default import ModelCache +from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase # This registers the subclasses that implement loaders of specific model types loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] for module in loaders: import_module(f"{__package__}.model_loaders.{module}") -__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"] +__all__ = [ + "LoadedModel", + "ModelCache", + "ModelConvertCache", + "ModelLoaderBase", + "ModelLoader", + 
"ModelLoaderRegistryBase", + "ModelLoaderRegistry", +] diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 4c5e899aa3b..b8ce56eb16d 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -1,37 +1,22 @@ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team """ Base class for model loading in InvokeAI. - -Use like this: - - loader = AnyModelLoader(...) - loaded_model = loader.get_model('019ab39adfa1840455') - with loaded_model as model: # context manager moves model into VRAM - # do something with loaded_model """ -import hashlib from abc import ABC, abstractmethod from dataclasses import dataclass from logging import Logger from pathlib import Path -from typing import Any, Callable, Dict, Optional, Tuple, Type +from typing import Any, Optional from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager.config import ( AnyModel, AnyModelConfig, - BaseModelType, - ModelConfigBase, - ModelFormat, - ModelType, SubModelType, - VaeCheckpointConfig, - VaeDiffusersConfig, ) from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase -from invokeai.backend.util.logging import InvokeAILogger @dataclass @@ -56,6 +41,14 @@ def model(self) -> AnyModel: return self._locker.model +# TODO(MM2): +# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't +# know about. I think the problem may be related to this class being an ABC. +# +# For example, GenericDiffusersLoader defines `get_hf_load_class()`, and StableDiffusionDiffusersModel attempts to +# call it. However, the method is not defined in the ABC, so it is not guaranteed to be implemented. + + class ModelLoaderBase(ABC): """Abstract base class for loading models into RAM/VRAM.""" @@ -71,7 +64,7 @@ def __init__( pass @abstractmethod - def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its confguration. @@ -90,106 +83,3 @@ def get_size_fs( ) -> int: """Return size in bytes of the model, calculated before loading.""" pass - - -# TO DO: Better name? 
-class AnyModelLoader: - """This class manages the model loaders and invokes the correct one to load a model of given base and type.""" - - # this tracks the loader subclasses - _registry: Dict[str, Type[ModelLoaderBase]] = {} - _logger: Logger = InvokeAILogger.get_logger() - - def __init__( - self, - app_config: InvokeAIAppConfig, - logger: Logger, - ram_cache: ModelCacheBase[AnyModel], - convert_cache: ModelConvertCacheBase, - ): - """Initialize AnyModelLoader with its dependencies.""" - self._app_config = app_config - self._logger = logger - self._ram_cache = ram_cache - self._convert_cache = convert_cache - - @property - def ram_cache(self) -> ModelCacheBase[AnyModel]: - """Return the RAM cache associated used by the loaders.""" - return self._ram_cache - - @property - def convert_cache(self) -> ModelConvertCacheBase: - """Return the convert cache associated used by the loaders.""" - return self._convert_cache - - def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """ - Return a model given its configuration. - - :param key: model key, as known to the config backend - :param submodel_type: an ModelType enum indicating the portion of - the model to retrieve (e.g. ModelType.Vae) - """ - implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type) - return implementation( - app_config=self._app_config, - logger=self._logger, - ram_cache=self._ram_cache, - convert_cache=self._convert_cache, - ).load_model(model_config, submodel_type) - - @staticmethod - def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: - return "-".join([base.value, type.value, format.value]) - - @classmethod - def get_implementation( - cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: - """Get subclass of ModelLoaderBase registered to handle base and type.""" - # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned - conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) - - key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type - key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any - implementation = cls._registry.get(key1) or cls._registry.get(key2) - if not implementation: - raise NotImplementedError( - f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" - ) - return implementation, conf2, submodel_type - - @classmethod - def _handle_subtype_overrides( - cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] - ) -> Tuple[ModelConfigBase, Optional[SubModelType]]: - if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: - model_path = Path(config.vae) - config_class = ( - VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig - ) - hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest() - new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash) - submodel_type = None - else: - new_conf = config - return new_conf, submodel_type - - @classmethod - def register( - cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any - ) -> Callable[[Type[ModelLoaderBase]], 
Type[ModelLoaderBase]]: - """Define a decorator which registers the subclass of loader.""" - - def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: - cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") - key = cls._to_registry_key(base, type, format) - if key in cls._registry: - raise Exception( - f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" - ) - cls._registry[key] = subclass - return subclass - - return decorator diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 79c9311de1d..642cffaf4be 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -1,13 +1,9 @@ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team """Default implementation of model loading in InvokeAI.""" -import sys from logging import Logger from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -from diffusers import ModelMixin -from diffusers.configuration_utils import ConfigMixin +from typing import Optional, Tuple from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager import ( @@ -25,17 +21,6 @@ from invokeai.backend.util.devices import choose_torch_device, torch_dtype -class ConfigLoader(ConfigMixin): - """Subclass of ConfigMixin for loading diffusers configuration files.""" - - @classmethod - def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: - """Load a diffusrs ConfigMixin configuration.""" - cls.config_name = kwargs.pop("config_name") - # Diffusers doesn't provide typing info - return super().load_config(*args, **kwargs) # type: ignore - - # TO DO: The loader is not thread safe! class ModelLoader(ModelLoaderBase): """Default implementation of ModelLoaderBase.""" @@ -137,43 +122,6 @@ def get_size_fs( variant=config.repo_variant if hasattr(config, "repo_variant") else None, ) - def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: - return ConfigLoader.load_config(model_path, config_name=config_name) - - # TO DO: Add exception handling - def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type - if module in ["diffusers", "transformers"]: - res_type = sys.modules[module] - else: - res_type = sys.modules["diffusers"].pipelines - result: ModelMixin = getattr(res_type, class_name) - return result - - # TO DO: Add exception handling - def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: - if submodel_type: - try: - config = self._load_diffusers_config(model_path, config_name="model_index.json") - module, class_name = config[submodel_type.value] - return self._hf_definition_to_type(module=module, class_name=class_name) - except KeyError as e: - raise InvalidModelConfigException( - f'The "{submodel_type}" submodel is not available for this model.' 
- ) from e - else: - try: - config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config.get("_class_name", None) - if class_name: - return self._hf_definition_to_type(module="diffusers", class_name=class_name) - if config.get("model_type", None) == "clip_vision_model": - class_name = config.get("architectures")[0] - return self._hf_definition_to_type(module="transformers", class_name=class_name) - if not class_name: - raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json") - except KeyError as e: - raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e - # This needs to be implemented in subclasses that handle checkpoints def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: raise NotImplementedError diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py index 209d7166f36..195e39361b4 100644 --- a/invokeai/backend/model_manager/load/memory_snapshot.py +++ b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -55,7 +55,7 @@ def capture(cls, run_garbage_collector: bool = True) -> Self: vram = None try: - malloc_info = LibcUtil().mallinfo2() # type: ignore + malloc_info = LibcUtil().mallinfo2() except (OSError, AttributeError): # OSError: This is expected in environments that do not have the 'libc.so.6' shared library. # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33) diff --git a/invokeai/backend/model_manager/load/model_loader_registry.py b/invokeai/backend/model_manager/load/model_loader_registry.py new file mode 100644 index 00000000000..ce1110e749b --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loader_registry.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +""" +This module implements a system in which model loaders register the +type, base and format of models that they know how to load. + +Use like this: + + cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore + loaded_model = cls( + app_config=app_config, + logger=logger, + ram_cache=ram_cache, + convert_cache=convert_cache + ).load_model(model_config, submodel_type) + +""" +import hashlib +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, Dict, Optional, Tuple, Type + +from ..config import ( + AnyModelConfig, + BaseModelType, + ModelConfigBase, + ModelFormat, + ModelType, + SubModelType, + VaeCheckpointConfig, + VaeDiffusersConfig, +) +from . import ModelLoaderBase + + +class ModelLoaderRegistryBase(ABC): + """This class allows model loaders to register their type, base and format.""" + + @classmethod + @abstractmethod + def register( + cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any + ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: + """Define a decorator which registers the subclass of loader.""" + + @classmethod + @abstractmethod + def get_implementation( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + """ + Get subclass of ModelLoaderBase registered to handle base and type. 
+ + Parameters: + :param config: Model configuration record, as returned by ModelRecordService + :param submodel_type: Submodel to fetch (main models only) + :return: tuple(loader_class, model_config, submodel_type) + + Note that the returned model config may be different from the one passed + in, in the event that a submodel type is provided. + """ + + +class ModelLoaderRegistry: + """ + This class allows model loaders to register their type, base and format. + """ + + _registry: Dict[str, Type[ModelLoaderBase]] = {} + + @classmethod + def register( + cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any + ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: + """Define a decorator which registers the subclass of loader.""" + + def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: + key = cls._to_registry_key(base, type, format) + if key in cls._registry: + raise Exception( + f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" + ) + cls._registry[key] = subclass + return subclass + + return decorator + + @classmethod + def get_implementation( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + """Get subclass of ModelLoaderBase registered to handle base and type.""" + # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned + conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) + + key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type + key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any + implementation = cls._registry.get(key1) or cls._registry.get(key2) + if not implementation: + raise NotImplementedError( + f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" + ) + return implementation, conf2, submodel_type + + @classmethod + def _handle_subtype_overrides( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[ModelConfigBase, Optional[SubModelType]]: + if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: + model_path = Path(config.vae) + config_class = ( + VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig + ) + hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest() + new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash) + submodel_type = None + else: + new_conf = config + return new_conf, submodel_type + + @staticmethod + def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: + return "-".join([base.value, type.value, format.value]) diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index d446d079336..43393f5a847 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -13,13 +13,13 @@ ModelType, ) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers -from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from ..
import ModelLoaderRegistry from .generic_diffusers import GenericDiffusersLoader -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) class ControlnetLoader(GenericDiffusersLoader): """Class to load ControlNet models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index 114e317f3c6..9a9b25aec53 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -1,24 +1,27 @@ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team """Class for simple diffusers model loading in InvokeAI.""" +import sys from pathlib import Path -from typing import Optional +from typing import Any, Dict, Optional + +from diffusers import ConfigMixin, ModelMixin from invokeai.backend.model_manager import ( AnyModel, BaseModelType, + InvalidModelConfigException, ModelFormat, ModelRepoVariant, ModelType, SubModelType, ) -from ..load_base import AnyModelLoader -from ..load_default import ModelLoader +from .. import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) class GenericDiffusersLoader(ModelLoader): """Class to load simple diffusers models.""" @@ -28,9 +31,60 @@ def _load_model( model_variant: Optional[ModelRepoVariant] = None, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - model_class = self._get_hf_load_class(model_path) + model_class = self.get_hf_load_class(model_path) if submodel_type is not None: raise Exception(f"There are no submodels in models of type {model_class}") variant = model_variant.value if model_variant else None result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore return result + + # TO DO: Add exception handling + def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: + """Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load.""" + if submodel_type: + try: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + result = self._hf_definition_to_type(module=module, class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException( + f'The "{submodel_type}" submodel is not available for this model.' 
+ ) from e + else: + try: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config.get("_class_name", None) + if class_name: + result = self._hf_definition_to_type(module="diffusers", class_name=class_name) + if config.get("model_type", None) == "clip_vision_model": + class_name = config.get("architectures") + assert class_name is not None + result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) + if not class_name: + raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") + except KeyError as e: + raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e + return result + + # TO DO: Add exception handling + def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type + if module in ["diffusers", "transformers"]: + res_type = sys.modules[module] + else: + res_type = sys.modules["diffusers"].pipelines + result: ModelMixin = getattr(res_type, class_name) + return result + + def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: + return ConfigLoader.load_config(model_path, config_name=config_name) + + +class ConfigLoader(ConfigMixin): + """Subclass of ConfigMixin for loading diffusers configuration files.""" + + @classmethod + def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Load a diffusers ConfigMixin configuration.""" + cls.config_name = kwargs.pop("config_name") + # Diffusers doesn't provide typing info + return super().load_config(*args, **kwargs) # type: ignore diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py index 27ced41c1e9..7d25e9d218c 100644 --- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py +++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py @@ -15,11 +15,10 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) class IPAdapterInvokeAILoader(ModelLoader): """Class to load IP Adapter diffusers models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 6ff2dcc9182..fe804ef5654 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -18,13 +18,13 @@ SubModelType, ) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from ..
import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) class LoraLoader(ModelLoader): """Class to load LoRA models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py index 935a6b7c953..38f0274acc6 100644 --- a/invokeai/backend/model_manager/load/model_loaders/onnx.py +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -13,13 +13,14 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +from .. import ModelLoaderRegistry +from .generic_diffusers import GenericDiffusersLoader -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) -class OnnyxDiffusersModel(ModelLoader): + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) +class OnnyxDiffusersModel(GenericDiffusersLoader): """Class to load onnx models.""" def _load_model( @@ -30,7 +31,7 @@ def _load_model( ) -> AnyModel: if not submodel_type is not None: raise Exception("A submodel type must be provided when loading onnx pipelines.") - load_class = self._get_hf_load_class(model_path, submodel_type) + load_class = self.get_hf_load_class(model_path, submodel_type) variant = model_variant.value if model_variant else None model_path = model_path / submodel_type.value result: AnyModel = load_class.from_pretrained( diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index 23b4e1fccd6..5884f84e8da 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -19,13 +19,14 @@ ) from invokeai.backend.model_manager.config import MainCheckpointConfig from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +from .. 
import ModelLoaderRegistry +from .generic_diffusers import GenericDiffusersLoader -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) -class StableDiffusionDiffusersModel(ModelLoader): + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) +class StableDiffusionDiffusersModel(GenericDiffusersLoader): """Class to load main models.""" model_base_to_model_type = { @@ -43,7 +44,7 @@ def _load_model( ) -> AnyModel: if not submodel_type is not None: raise Exception("A submodel type must be provided when loading main pipelines.") - load_class = self._get_hf_load_class(model_path, submodel_type) + load_class = self.get_hf_load_class(model_path, submodel_type) variant = model_variant.value if model_variant else None model_path = model_path / submodel_type.value result: AnyModel = load_class.from_pretrained( diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py index 94767479609..094d4d7c5c3 100644 --- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Optional, Tuple -from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, @@ -15,12 +14,15 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.textual_inversion import TextualInversionModelRaw +from .. import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder) + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) +@ModelLoaderRegistry.register( + base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder +) class TextualInversionLoader(ModelLoader): """Class to load TI models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 3983ea75950..7ade1494eb1 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -14,14 +14,14 @@ ModelType, ) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers -from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from .. 
import ModelLoaderRegistry from .generic_diffusers import GenericDiffusersLoader -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) -@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) class VaeLoader(GenericDiffusersLoader): """Class to load VAE models.""" diff --git a/invokeai/backend/model_manager/load/optimizations.py b/invokeai/backend/model_manager/load/optimizations.py index a46d262175f..030fcfa639a 100644 --- a/invokeai/backend/model_manager/load/optimizations.py +++ b/invokeai/backend/model_manager/load/optimizations.py @@ -1,16 +1,16 @@ from contextlib import contextmanager +from typing import Any, Generator import torch -def _no_op(*args, **kwargs): +def _no_op(*args: Any, **kwargs: Any) -> None: pass @contextmanager -def skip_torch_weight_init(): - """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) - to skip weight initialization. +def skip_torch_weight_init() -> Generator[None, None, None]: + """Monkey patch several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) to skip weight initialization. By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is @@ -18,13 +18,14 @@ def skip_torch_weight_init(): monkey-patches common torch layers to skip the weight initialization step. 
""" torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding] - saved_functions = [m.reset_parameters for m in torch_modules] + saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules] try: for torch_module in torch_modules: + assert hasattr(torch_module, "reset_parameters") torch_module.reset_parameters = _no_op - yield None finally: for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): + assert hasattr(torch_module, "reset_parameters") torch_module.reset_parameters = saved_function diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 108f1f0e6f7..7063cb907d2 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -13,7 +13,7 @@ import torch from diffusers import AutoPipelineForText2Image -from diffusers import logging as dlogging +from diffusers.utils import logging as dlogging from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -76,7 +76,7 @@ def merge_diffusion_models( custom_pipeline="checkpoint_merger", torch_dtype=dtype, variant=variant, - ) + ) # type: ignore merged_pipe = pipe.merge( pretrained_model_name_or_path_list=model_paths, alpha=alpha, diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 5c3afcdc960..6e410d82220 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel): AllowDifferentLicense: bool = Field( description="if true, derivatives of this model be redistributed under a different license", default=False ) - AllowCommercialUse: CommercialUsage = Field( - description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set + AllowCommercialUse: Optional[CommercialUsage] = Field( + description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None ) @@ -139,7 +139,10 @@ def credit_required(self) -> bool: @property def allow_commercial_use(self) -> bool: """Return True if commercial use is allowed.""" - return self.restrictions.AllowCommercialUse != CommercialUsage("None") + if self.restrictions.AllowCommercialUse is None: + return False + else: + return self.restrictions.AllowCommercialUse != CommercialUsage("None") @property def allow_derivatives(self) -> bool: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index d511ffa875f..7de4289466d 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -8,7 +8,6 @@ from picklescan.scanner import scan_file_path import invokeai.backend.util.logging as logger -from .util.model_util import lora_token_vector_length, read_checkpoint_meta from invokeai.backend.util.util import SilenceWarnings from .config import ( @@ -23,6 +22,7 @@ SchedulerPredictionType, ) from .hash import FastModelHash +from .util.model_util import lora_token_vector_length, read_checkpoint_meta CkptType = Dict[str, Any] @@ -53,6 +53,7 @@ }, } + class ProbeBase(object): """Base class for probes.""" diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 0ead22b743f..f7ef2e049d4 100644 --- a/invokeai/backend/model_manager/search.py +++ 
b/invokeai/backend/model_manager/search.py @@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase): # returns all models that have 'anime' in the path """ - models_found: Optional[Set[Path]] = Field(default=None) - scanned_dirs: Optional[Set[Path]] = Field(default=None) - pruned_paths: Optional[Set[Path]] = Field(default=None) + models_found: Set[Path] = Field(default_factory=set) + scanned_dirs: Set[Path] = Field(default_factory=set) + pruned_paths: Set[Path] = Field(default_factory=set) def search_started(self) -> None: self.models_found = set() diff --git a/invokeai/backend/model_manager/util/libc_util.py b/invokeai/backend/model_manager/util/libc_util.py index 1fbcae0a93c..ef1ac2f8a4b 100644 --- a/invokeai/backend/model_manager/util/libc_util.py +++ b/invokeai/backend/model_manager/util/libc_util.py @@ -35,7 +35,7 @@ class Struct_mallinfo2(ctypes.Structure): ("keepcost", ctypes.c_size_t), ] - def __str__(self): + def __str__(self) -> str: s = "" s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n" s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n" @@ -62,7 +62,7 @@ class LibcUtil: TODO: Improve cross-OS compatibility of this class. """ - def __init__(self): + def __init__(self) -> None: self._libc = ctypes.cdll.LoadLibrary("libc.so.6") def mallinfo2(self) -> Struct_mallinfo2: @@ -72,4 +72,5 @@ def mallinfo2(self) -> Struct_mallinfo2: """ mallinfo2 = self._libc.mallinfo2 mallinfo2.restype = Struct_mallinfo2 - return mallinfo2() + result: Struct_mallinfo2 = mallinfo2() + return result diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py index 6847a40878c..2e448520e56 100644 --- a/invokeai/backend/model_manager/util/model_util.py +++ b/invokeai/backend/model_manager/util/model_util.py @@ -1,12 +1,15 @@ """Utilities for parsing model files, used mostly by probe.py""" import json -import torch -from typing import Union from pathlib import Path +from typing import Dict, Optional, Union + +import safetensors +import torch from picklescan.scanner import scan_file_path -def _fast_safetensors_reader(path: str): + +def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]: checkpoint = {} device = torch.device("meta") with open(path, "rb") as f: @@ -37,10 +40,12 @@ def _fast_safetensors_reader(path: str): return checkpoint -def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): + +def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]: if str(path).endswith(".safetensors"): try: - checkpoint = _fast_safetensors_reader(path) + path_str = path.as_posix() if isinstance(path, Path) else path + checkpoint = _fast_safetensors_reader(path_str) except Exception: # TODO: create issue for support "meta"? checkpoint = safetensors.torch.load_file(path, device="cpu") @@ -52,14 +57,15 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): checkpoint = torch.load(path, map_location=torch.device("meta")) return checkpoint -def lora_token_vector_length(checkpoint: dict) -> int: + +def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: """ Given a checkpoint in memory, return the lora token vector length :param checkpoint: The checkpoint """ - def _get_shape_1(key: str, tensor, checkpoint) -> int: + def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: lora_token_vector_length = None if "." 
not in key: diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py index 9b2096abdf0..8916865dd52 100644 --- a/invokeai/backend/onnx/onnx_runtime.py +++ b/invokeai/backend/onnx/onnx_runtime.py @@ -8,6 +8,7 @@ import onnx from onnx import numpy_helper from onnxruntime import InferenceSession, SessionOptions, get_available_providers + from ..raw_model import RawModel ONNX_WEIGHTS_NAME = "model.onnx" @@ -15,7 +16,7 @@ # NOTE FROM LS: This was copied from Stalker's original implementation. # I have not yet gone through and fixed all the type hints -class IAIOnnxRuntimeModel: +class IAIOnnxRuntimeModel(RawModel): class _tensor_access: def __init__(self, model): # type: ignore self.model = model diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py index 2e224d538b3..d0dc50c4560 100644 --- a/invokeai/backend/raw_model.py +++ b/invokeai/backend/raw_model.py @@ -10,5 +10,6 @@ that adds additional methods and attributes. """ + class RawModel: """Base class for 'Raw' model wrappers.""" diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py index bfdf9e0c536..fb9112b56dc 100644 --- a/invokeai/backend/stable_diffusion/seamless.py +++ b/invokeai/backend/stable_diffusion/seamless.py @@ -1,10 +1,11 @@ from __future__ import annotations from contextlib import contextmanager -from typing import List, Union +from typing import Callable, List, Union import torch.nn as nn -from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel def _conv_forward_asymmetric(self, input, weight, bias): @@ -26,70 +27,51 @@ def _conv_forward_asymmetric(self, input, weight, bias): @contextmanager def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): + # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor + to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = [] try: - to_restore = [] - + # Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence + skipped_layers = 1 for m_name, m in model.named_modules(): - if isinstance(model, UNet2DConditionModel): - if ".attentions." in m_name: - continue + if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + continue - if ".resnets." in m_name: - if ".conv2" in m_name: - continue - if ".conv_shortcut" in m_name: - continue - - """ - if isinstance(model, UNet2DConditionModel): - if False and ".upsamplers." in m_name: - continue + if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name: + # down_blocks.1.resnets.1.conv1 + _, block_num, _, resnet_num, submodule_name = m_name.split(".") + block_num = int(block_num) + resnet_num = int(resnet_num) - if False and ".downsamplers." in m_name: + if block_num >= len(model.down_blocks) - skipped_layers: continue - if True and ".resnets." in m_name: - if True and ".conv1" in m_name: - if False and "down_blocks" in m_name: - continue - if False and "mid_block" in m_name: - continue - if False and "up_blocks" in m_name: - continue - - if True and ".conv2" in m_name: - continue - - if True and ".conv_shortcut" in m_name: - continue - - if True and ".attentions." 
in m_name: + # Skip the second resnet (could be configurable) + if resnet_num > 0: continue - if False and m_name in ["conv_in", "conv_out"]: + # Skip Conv2d layers (could be configurable) + if submodule_name == "conv2": continue - """ - - if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - m.asymmetric_padding_mode = {} - m.asymmetric_padding = {} - m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" - m.asymmetric_padding["x"] = ( - m._reversed_padding_repeated_twice[0], - m._reversed_padding_repeated_twice[1], - 0, - 0, - ) - m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" - m.asymmetric_padding["y"] = ( - 0, - 0, - m._reversed_padding_repeated_twice[2], - m._reversed_padding_repeated_twice[3], - ) - to_restore.append((m, m._conv_forward)) - m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) + + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) yield diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py index 9a4fa0b5402..f7390979bbc 100644 --- a/invokeai/backend/textual_inversion.py +++ b/invokeai/backend/textual_inversion.py @@ -8,8 +8,10 @@ from safetensors.torch import load_file from transformers import CLIPTokenizer from typing_extensions import Self + from .raw_model import RawModel + class TextualInversionModelRaw(RawModel): embedding: torch.Tensor # [n, 768]|[n, 1280] embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index a3def182c8c..0d76c4633cf 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -42,7 +42,7 @@ def install_and_load_model( # If the requested model is already installed, return its LoadedModel with contextlib.suppress(UnknownModelException): # TODO: Replace with wrapper call - loaded_model: LoadedModel = model_manager.load.load_model_by_attr( + loaded_model: LoadedModel = model_manager.load_model_by_attr( model_name=model_name, base_model=base_model, model_type=model_type ) return loaded_model @@ -53,7 +53,7 @@ def install_and_load_model( assert job.complete try: - loaded_model = model_manager.load.load_model_by_config(job.config_out) + loaded_model = model_manager.load_model_by_config(job.config_out) return loaded_model except UnknownModelException as e: raise Exception( diff --git a/tests/backend/model_manager/model_loading/test_model_load.py b/tests/backend/model_manager/model_loading/test_model_load.py index 38d9b8afb8c..c1fde504eae 100644 --- a/tests/backend/model_manager/model_loading/test_model_load.py +++ b/tests/backend/model_manager/model_loading/test_model_load.py @@ -4,18 +4,27 @@ from pathlib import Path -from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.app.services.model_load import ModelLoadServiceBase +from invokeai.app.services.model_manager import ModelManagerServiceBase from 
invokeai.backend.textual_inversion import TextualInversionModelRaw from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 -def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path): - store = mm2_installer.record_store + +def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path): + store = mm2_model_manager.store matches = store.search_by_attr(model_name="test_embedding") assert len(matches) == 0 - key = mm2_installer.register_path(embedding_file) - loaded_model = mm2_loader.load_model_by_config(store.get_model(key)) + key = mm2_model_manager.install.register_path(embedding_file) + loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key)) assert loaded_model is not None assert loaded_model.config.key == key with loaded_model as model: assert isinstance(model, TextualInversionModelRaw) + loaded_model_2 = mm2_model_manager.load_model_by_key(key) + assert loaded_model.config.key == loaded_model_2.config.key + + loaded_model_3 = mm2_model_manager.load_model_by_attr( + model_name=loaded_model.config.name, + model_type=loaded_model.config.type, + base_model=loaded_model.config.base, + ) + assert loaded_model.config.key == loaded_model_3.config.key diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 5f7f44c0188..df54e2f9267 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -6,17 +6,17 @@ from typing import Any, Dict, List import pytest -from pytest import FixtureRequest from pydantic import BaseModel +from pytest import FixtureRequest from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService +from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase -from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService -from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase +from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase +from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( @@ -95,9 +95,7 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: @pytest.fixture -def mm2_download_queue(mm2_session: Session, - request: FixtureRequest - ) -> DownloadQueueServiceBase: +def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase: download_queue = DownloadQueueService(requests_session=mm2_session) download_queue.start() @@ -107,30 +105,34 @@ def stop_queue() -> None: request.addfinalizer(stop_queue) return download_queue + @pytest.fixture def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: return mm2_record_store.metadata_store + @pytest.fixture def mm2_loader(mm2_app_config: InvokeAIAppConfig, 
mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: ram_cache = ModelCache( logger=InvokeAILogger.get_logger(), max_cache_size=mm2_app_config.ram_cache_size, - max_vram_cache_size=mm2_app_config.vram_cache_size + max_vram_cache_size=mm2_app_config.vram_cache_size, ) convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) - return ModelLoadService(app_config=mm2_app_config, - record_store=mm2_record_store, - ram_cache=ram_cache, - convert_cache=convert_cache, - ) + return ModelLoadService( + app_config=mm2_app_config, + ram_cache=ram_cache, + convert_cache=convert_cache, + ) + @pytest.fixture -def mm2_installer(mm2_app_config: InvokeAIAppConfig, - mm2_download_queue: DownloadQueueServiceBase, - mm2_session: Session, - request: FixtureRequest, - ) -> ModelInstallServiceBase: +def mm2_installer( + mm2_app_config: InvokeAIAppConfig, + mm2_download_queue: DownloadQueueServiceBase, + mm2_session: Session, + request: FixtureRequest, +) -> ModelInstallServiceBase: logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) events = DummyEventService() @@ -213,15 +215,13 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas store.add_model("test_config_5", raw5) return store + @pytest.fixture -def mm2_model_manager(mm2_record_store: ModelRecordServiceBase, - mm2_installer: ModelInstallServiceBase, - mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase: - return ModelManagerService( - store=mm2_record_store, - install=mm2_installer, - load=mm2_loader - ) +def mm2_model_manager( + mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase +) -> ModelManagerServiceBase: + return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader) + @pytest.fixture def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: @@ -306,5 +306,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: ), ) return sess - - diff --git a/tests/backend/model_manager/test_lora.py b/tests/backend/model_manager/test_lora.py index e124bb68efc..114a4cfdcff 100644 --- a/tests/backend/model_manager/test_lora.py +++ b/tests/backend/model_manager/test_lora.py @@ -5,8 +5,8 @@ import pytest import torch -from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.lora import LoRALayer, LoRAModelRaw +from invokeai.backend.model_patcher import ModelPatcher @pytest.mark.parametrize( diff --git a/tests/backend/model_manager/test_memory_snapshot.py b/tests/backend/model_manager/test_memory_snapshot.py index 87ec8c34ee0..d31ae79b668 100644 --- a/tests/backend/model_manager/test_memory_snapshot.py +++ b/tests/backend/model_manager/test_memory_snapshot.py @@ -1,7 +1,8 @@ import pytest -from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2 + def test_memory_snapshot_capture(): """Smoke test of MemorySnapshot.capture()."""
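Note on the seamless.py hunk earlier in this patch: the body of _conv_forward_asymmetric lies outside the changed region, so the diff shows the asymmetric_padding attributes being set without showing how they are consumed. The standalone sketch below (not part of the patch; seamless_x_forward is an illustrative name, not an InvokeAI API) shows the general idea: pad the width circularly so the convolution wraps around in x, pad the height with zeros as usual, then convolve with built-in padding disabled. The patched Conv2d/ConvTranspose2d modules do the equivalent per axis, driven by asymmetric_padding_mode["x"] / ["y"], and set_seamless restores the saved _conv_forward methods when the context manager exits.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def seamless_x_forward(conv: nn.Conv2d, x: torch.Tensor) -> torch.Tensor:
        """Run `conv` so its receptive field wraps around horizontally (seamless in x)."""
        # For a Conv2d, _reversed_padding_repeated_twice is (left, right, top, bottom),
        # which is the order F.pad expects for the last two dimensions of a 4D tensor.
        left, right, top, bottom = conv._reversed_padding_repeated_twice
        x = F.pad(x, (left, right, 0, 0), mode="circular")   # wrap around in x
        x = F.pad(x, (0, 0, top, bottom), mode="constant")   # ordinary zero padding in y
        # Padding has already been applied above, so convolve with padding=0.
        return F.conv2d(x, conv.weight, conv.bias, conv.stride, 0, conv.dilation, conv.groups)

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    image = torch.randn(1, 3, 16, 16)
    assert seamless_x_forward(conv, image).shape == (1, 8, 16, 16)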