From 76010a51ef1f3805a7487723599035bc2356c3fb Mon Sep 17 00:00:00 2001 From: wangqiuwen Date: Sat, 7 Oct 2023 15:36:01 +0800 Subject: [PATCH 1/2] up --- modules/sd_models.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index eedb38c65ad..3a060ab685a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,4 +1,5 @@ import collections +import copy import os.path import sys import gc @@ -309,8 +310,6 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): if checkpoint_info in checkpoints_loaded: # use checkpoint cache print(f"Loading weights [{sd_model_hash}] from cache") - # move to end as latest - checkpoints_loaded.move_to_end(checkpoint_info) return checkpoints_loaded[checkpoint_info] print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") @@ -352,12 +351,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if model.is_sdxl: sd_models_xl.extend_sdxl(model) - model.load_state_dict(state_dict, strict=False) - timer.record("apply weights to model") - if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model - checkpoints_loaded[checkpoint_info] = state_dict + checkpoints_loaded[checkpoint_info] = copy.deepcopy(state_dict) + + model.load_state_dict(state_dict, strict=False) + timer.record("apply weights to model") del state_dict From 770ee23f18d12fb3b5627c636aa420f481e292ee Mon Sep 17 00:00:00 2001 From: wangqiuwen Date: Sat, 7 Oct 2023 15:38:50 +0800 Subject: [PATCH 2/2] reverst --- modules/sd_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3a060ab685a..8d63e7f10f7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -310,6 +310,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): if checkpoint_info in checkpoints_loaded: # use checkpoint cache print(f"Loading weights [{sd_model_hash}] from cache") + # move to end as latest + checkpoints_loaded.move_to_end(checkpoint_info) return checkpoints_loaded[checkpoint_info] print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")