Improved memory management. (#5450)
* Less fragile memory management.

* Fix issue.

* Remove useless function.

* Prevent and detect some types of memory leaks.

* Run garbage collector when switching workflow if needed.

* Fix issue.
comfyanonymous authored Dec 2, 2024
1 parent 2d5b3e0 commit 79d5cea
Showing 4 changed files with 118 additions and 119 deletions.
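The changes below center on holding models through weak references and attaching finalizers, so the loader's bookkeeping no longer keeps models alive and cleanup can run once the last real reference disappears. As a rough illustration of that general pattern only (not ComfyUI code; Model, LoadedModelSketch, and on_model_collected are hypothetical names):

import gc
import weakref

class Model:
    """Hypothetical stand-in for a loaded model object."""
    def __init__(self, name):
        self.name = name

def on_model_collected(name):
    # weakref.finalize runs this once the referent has been garbage collected.
    print(f"finalizer: {name} was collected")

class LoadedModelSketch:
    """Tracks a model without extending its lifetime."""
    def __init__(self, model):
        self._model = weakref.ref(model)  # weak reference, not a strong one
        self._finalizer = weakref.finalize(model, on_model_collected, model.name)

    @property
    def model(self):
        return self._model()  # None once the model has been collected

m = Model("unet")
tracked = LoadedModelSketch(m)
print(tracked.model is m)  # True while a strong reference exists
del m                      # drop the only strong reference
gc.collect()               # only needed when reference cycles are involved
print(tracked.model)       # None: the finalizer has already fired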
comfy/model_management.py (198 changes: 85 additions & 113 deletions)
@@ -23,6 +23,8 @@
import torch
import sys
import platform
import weakref
import gc

class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -287,11 +289,27 @@ def module_size(module):

class LoadedModel:
def __init__(self, model):
self.model = model
self._set_model(model)
self.device = model.load_device
self.weights_loaded = False
self.real_model = None
self.currently_used = True
self.model_finalizer = None
self._patcher_finalizer = None

def _set_model(self, model):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)

def _switch_parent(self):
model = self._parent_model()
if model is not None:
self._set_model(model)

@property
def model(self):
return self._model()

def model_memory(self):
return self.model.model_size()
@@ -306,32 +324,23 @@ def model_memory_required(self, device):
return self.model_memory()

def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
patch_model_to = self.device

self.model.model_patches_to(self.device)
self.model.model_patches_to(self.model.model_dtype())

load_weights = not self.weights_loaded
# if self.model.loaded_size() > 0:
use_more_vram = lowvram_model_memory
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model

if self.model.loaded_size() > 0:
use_more_vram = lowvram_model_memory
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram)
else:
try:
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e

if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)

self.weights_loaded = True
return self.real_model
self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
return real_model

def should_reload_model(self, force_patch_weights=False):
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
@@ -344,18 +353,23 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True):
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
return True

def model_use_more_vram(self, extra_memory):
return self.model.partially_load(self.device, extra_memory)
def model_use_more_vram(self, extra_memory, force_patch_weights=False):
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)

def __eq__(self, other):
return self.model is other.model

def __del__(self):
if self._patcher_finalizer is not None:
self._patcher_finalizer.detach()


def use_more_memory(extra_memory, loaded_models, device):
for m in loaded_models:
if m.device == device:
@@ -386,38 +400,8 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload

if len(to_unload) == 0:
return True

same_weights = 0
for i in to_unload:
if model.clone_has_same_weights(current_loaded_models[i].model):
same_weights += 1

if same_weights == len(to_unload):
unload_weight = False
else:
unload_weight = True

if not force_unload:
if unload_weights_only and unload_weight == False:
return None
else:
unload_weight = True

for i in to_unload:
logging.debug("unload clone {} {}".format(i, unload_weight))
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)

return unload_weight

def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
unloaded_model = []
can_unload = []
unloaded_models = []
Expand Down Expand Up @@ -454,6 +438,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
return unloaded_models

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state

inference_memory = minimum_inference_memory()
@@ -466,63 +451,45 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
models = set(models)

models_to_load = []
models_already_loaded = []

for x in models:
loaded_model = LoadedModel(x)
loaded = None

try:
loaded_model_index = current_loaded_models.index(loaded_model)
except:
loaded_model_index = None

if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
loaded.currently_used = True
models_already_loaded.append(loaded)

if loaded is None:
loaded.currently_used = True
models_to_load.append(loaded)
else:
if hasattr(x, "model"):
logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model)

if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
free_mem = get_free_memory(d)
if free_mem < minimum_memory_required:
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
models_to_load = free_memory(minimum_memory_required, d)
logging.info("{} models unloaded.".format(len(models_to_load)))
else:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
if len(models_to_load) == 0:
return

logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
for loaded_model in models_to_load:
to_unload = []
for i in range(len(current_loaded_models)):
if loaded_model.model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
for i in to_unload:
current_loaded_models.pop(i).model.detach(unpatch_all=False)

total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

for loaded_model in models_already_loaded:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)

for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
free_mem = get_free_memory(device)
if free_mem < minimum_memory_required:
models_l = free_memory(minimum_memory_required, device)
logging.info("{} models unloaded.".format(len(models_l)))

for loaded_model in models_to_load:
model = loaded_model.model
@@ -544,17 +511,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu

cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)


devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_mem = get_free_memory(d)
if free_mem > minimum_memory_required:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
return


def load_model_gpu(model):
return load_models_gpu([model])

@@ -568,21 +526,35 @@ def loaded_models(only_currently_used=False):
output.append(m.model)
return output

def cleanup_models(keep_clone_weights_loaded=False):

def cleanup_models_gc():
do_gc = False
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.real_model() is not None and cur.model is None:
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
do_gc = True
break

if do_gc:
gc.collect()
soft_empty_cache()

for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.real_model() is not None and cur.model is None:
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))



def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
#TODO: very fragile function needs improvement
num_refs = sys.getrefcount(current_loaded_models[i].model)
if num_refs <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
#TODO: find a less fragile way to do this.
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
to_delete = [i] + to_delete
if current_loaded_models[i].real_model() is None:
to_delete = [i] + to_delete

for i in to_delete:
x = current_loaded_models.pop(i)
x.model_unload()
del x

def dtype_size(dtype):
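For reference, a condensed sketch of the leak check that the new cleanup_models_gc() above performs: if the inner torch module (reachable through the real_model weak reference) is still alive after the patcher weak reference has gone dead, a reference cycle is probably keeping it around, so a full garbage collection is forced and anything that survives is reported. This is a simplified restatement for readability, not the exact code; detect_and_collect and loaded_models are placeholder names.

import gc
import logging

def detect_and_collect(loaded_models):
    # Each entry is assumed to expose real_model (a weakref to the torch module)
    # and a model property that dereferences a weakref to the patcher, as in the diff.
    leaked = [lm for lm in loaded_models
              if lm.real_model() is not None and lm.model is None]
    if leaked:
        logging.info("potential memory leak detected, running a full garbage collect")
        gc.collect()
        # Anything still alive after a full collect is reported as a real leak.
        for lm in loaded_models:
            if lm.real_model() is not None and lm.model is None:
                logging.warning("model %s is still referenced from somewhere",
                                lm.real_model().__class__.__name__)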
comfy/model_patcher.py (36 changes: 32 additions & 4 deletions)
@@ -139,6 +139,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
self.patches_uuid = uuid.uuid4()
self.parent = None

if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -149,6 +150,9 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
if not hasattr(self.model, 'model_lowvram'):
self.model.model_lowvram = False

if not hasattr(self.model, 'current_weight_patches_uuid'):
self.model.current_weight_patches_uuid = None

def model_size(self):
if self.size > 0:
return self.size
@@ -172,6 +176,7 @@ def clone(self):
n.model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
return n

def is_clone(self, other):
@@ -464,6 +469,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
self.model.current_weight_patches_uuid = self.patches_uuid

def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
for k in self.object_patches:
@@ -498,6 +504,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
else:
comfy.utils.set_attr_param(self.model, k, bk.weight)

self.model.current_weight_patches_uuid = None
self.backup.clear()

if device_to is not None:
@@ -568,21 +575,42 @@ def partially_unload(self, device_to, memory_to_free=0):
self.model.model_loaded_weight_memory -= memory_freed
return memory_freed

def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
# TODO: force_patch_weights should not unload + reload full model
used = self.model.model_loaded_weight_memory
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
if unpatch_weights:
extra_memory += (used - self.model.model_loaded_weight_memory)

self.patch_model(load_weights=False)
full_load = False
if self.model.model_lowvram == False:
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
return 0
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True
current_used = self.model.model_loaded_weight_memory
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
try:
self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
except Exception as e:
self.detach()
raise e

return self.model.model_loaded_weight_memory - current_used

def detach(self, unpatch_all=True):
self.model_patches_to(self.offload_device)
if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
return self.model

def current_loaded_device(self):
return self.model.device

def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)

def __del__(self):
self.detach(unpatch_all=False)

