From 0b58299ab053f9e4caf0c4bc4f2f8796c7fe1504 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 9 Nov 2025 16:02:32 -0500 Subject: [PATCH] Unload weights if vram usage goes up between runs. --- comfy/model_management.py | 11 +++++++++-- comfy/model_patcher.py | 20 +++++++++++++------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7012df858ab2..a4410f2ece53 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -503,7 +503,11 @@ def model_load(self, lowvram_model_memory=0, force_patch_weights=False): use_more_vram = lowvram_model_memory if use_more_vram == 0: use_more_vram = 1e32 - self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + if use_more_vram > 0: + self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + else: + self.model.partially_unload(self.model.offload_device, -use_more_vram, force_patch_weights=force_patch_weights) + real_model = self.model.model if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: @@ -689,7 +693,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu current_free_mem = get_free_memory(torch_dev) + loaded_memory lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) - lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) + lowvram_model_memory = lowvram_model_memory - loaded_memory + + if lowvram_model_memory == 0: + lowvram_model_memory = 0.1 if vram_set_state == VRAMState.NO_VRAM: lowvram_model_memory = 0.1 diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5a31a8734b07..04ce02acc2d6 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -843,7 +843,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.object_patches_backup.clear() - def partially_unload(self, device_to, memory_to_free=0): + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 @@ -887,13 +887,19 @@ def partially_unload(self, device_to, memory_to_free=0): module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, weight_key) - m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + _, set_func, convert_func = get_key_weight(self.model, weight_key) + m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) + patch_counter += 1 if bias_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, bias_key) - m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + _, set_func, convert_func = get_key_weight(self.model, bias_key) + m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) + patch_counter += 1 cast_weight = True if cast_weight: