Commit dea899f

Unload weights if vram usage goes up between runs. (#10690)
1 parent e632e5d commit dea899f

2 files changed: +22 -9 lines changed

comfy/model_management.py

Lines changed: 9 additions & 2 deletions
@@ -503,7 +503,11 @@ def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
         use_more_vram = lowvram_model_memory
         if use_more_vram == 0:
             use_more_vram = 1e32
-        self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
+        if use_more_vram > 0:
+            self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
+        else:
+            self.model.partially_unload(self.model.offload_device, -use_more_vram, force_patch_weights=force_patch_weights)
+
         real_model = self.model.model

         if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
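
In plain terms: model_load previously always grew the model's VRAM footprint, while after this change a non-positive budget is routed to partially_unload instead. A minimal sketch of that dispatch, with a hypothetical `loaded` object standing in for the real LoadedModel (method names come from the diff):

# Minimal sketch of the dispatch added above; `loaded` is a hypothetical
# stand-in for the real LoadedModel.
def apply_vram_budget(loaded, budget, force_patch_weights=False):
    if budget == 0:
        budget = 1e32  # 0 is the "no limit" sentinel, as in the diff
    if budget > 0:
        # grow: move up to `budget` bytes of weights onto the GPU
        loaded.model_use_more_vram(budget, force_patch_weights=force_patch_weights)
    else:
        # shrink: a negative budget means VRAM usage went up between runs,
        # so free abs(budget) bytes back to the offload device
        loaded.model.partially_unload(loaded.model.offload_device, -budget,
                                      force_patch_weights=force_patch_weights)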
@@ -689,7 +693,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
             current_free_mem = get_free_memory(torch_dev) + loaded_memory

             lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
-            lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
+            lowvram_model_memory = lowvram_model_memory - loaded_memory
+
+            if lowvram_model_memory == 0:
+                lowvram_model_memory = 0.1

         if vram_set_state == VRAMState.NO_VRAM:
             lowvram_model_memory = 0.1
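
This arithmetic change is what makes negative budgets possible: the old max(0.1, ...) clamp is gone, so when a model already holds more memory than the computed budget, the difference goes negative and flows into model_load above. A worked example with illustrative numbers:

# Worked example of the new budget arithmetic (numbers are illustrative).
current_free_mem = 2 * 1024**3   # 2 GiB free (including the model's share)
loaded_memory    = 3 * 1024**3   # the model already occupies 3 GiB

lowvram_model_memory = current_free_mem   # stands in for the max(...) above
lowvram_model_memory -= loaded_memory     # -> -1 GiB: budget is now negative

if lowvram_model_memory == 0:
    lowvram_model_memory = 0.1   # keep 0 from meaning "unlimited" downstream

# Before this commit the value was clamped to at least 0.1; now -1 GiB
# reaches model_load(), which asks partially_unload() to free that much.
print(lowvram_model_memory)      # -1073741824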

comfy/model_patcher.py

Lines changed: 13 additions & 7 deletions
@@ -843,7 +843,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):

         self.object_patches_backup.clear()

-    def partially_unload(self, device_to, memory_to_free=0):
+    def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
         with self.use_ejected():
             hooks_unpatched = False
             memory_freed = 0
@@ -887,13 +887,19 @@ def partially_unload(self, device_to, memory_to_free=0):
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
-                            _, set_func, convert_func = get_key_weight(self.model, weight_key)
-                            m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
-                            patch_counter += 1
+                            if force_patch_weights:
+                                self.patch_weight_to_device(weight_key)
+                            else:
+                                _, set_func, convert_func = get_key_weight(self.model, weight_key)
+                                m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
+                                patch_counter += 1
                         if bias_key in self.patches:
-                            _, set_func, convert_func = get_key_weight(self.model, bias_key)
-                            m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
-                            patch_counter += 1
+                            if force_patch_weights:
+                                self.patch_weight_to_device(bias_key)
+                            else:
+                                _, set_func, convert_func = get_key_weight(self.model, bias_key)
+                                m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
+                                patch_counter += 1
                         cast_weight = True

                     if cast_weight:
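
The effect of the new flag can be seen with a toy, self-contained model of the branch: with force_patch_weights the pending patch is applied to the weight immediately (patch_weight_to_device), otherwise a deferred LowVramPatch-style hook is queued and patch_counter is bumped. Class and attribute names below are hypothetical stubs; only the branch shape mirrors the diff:

# Toy stand-in for the force_patch_weights branch added above.
# ToyPatcher and its attributes are hypothetical; the branch mirrors the diff.
class ToyPatcher:
    def __init__(self, patches):
        self.patches = patches    # key -> pending patch (e.g. a LoRA delta)
        self.baked = {}           # key -> eagerly patched weight
        self.hooks = []           # keys with deferred low-VRAM patch hooks
        self.patch_counter = 0

    def patch_weight_to_device(self, key):
        # eager path: bake the pending patch into the weight now
        self.baked[key] = self.patches[key]

    def offload_key(self, key, force_patch_weights):
        if key not in self.patches:
            return
        if force_patch_weights:
            self.patch_weight_to_device(key)
        else:
            # deferred path: patch lazily at cast time, as LowVramPatch does
            self.hooks.append(key)
            self.patch_counter += 1

p = ToyPatcher({"blocks.0.attn.weight": "lora-delta"})
p.offload_key("blocks.0.attn.weight", force_patch_weights=True)
assert p.baked and not p.hooks   # eager: no deferred hook was installed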
