39 changes: 27 additions & 12 deletions comfy/model_management.py
@@ -577,14 +577,14 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
-def free_memory(memory_required, device, keep_loaded=[]):
+def free_memory(memory_required, device, keep_loaded=[], loaded_models=current_loaded_models):
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
 
-    for i in range(len(current_loaded_models) -1, -1, -1):
-        shift_model = current_loaded_models[i]
+    for i in range(len(loaded_models) -1, -1, -1):
+        shift_model = loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
                 can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
@@ -598,12 +598,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
             if free_mem > memory_required:
                 break
             memory_to_free = memory_required - free_mem
-        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
+        logging.info(f"Unloading {loaded_models[i].model.model.__class__.__name__}")
+        if loaded_models[i].model_unload(memory_to_free):
             unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
-        unloaded_models.append(current_loaded_models.pop(i))
+        unloaded_models.append(loaded_models.pop(i))
 
     if len(unloaded_model) > 0:
         soft_empty_cache()
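
These two hunks make the unload scope a parameter: free_memory() now walks whatever list it is handed instead of always walking the global current_loaded_models, and pops unloaded entries from that same list. Note that the default loaded_models=current_loaded_models is evaluated once, at definition time; this is only safe because the module mutates that list in place and never rebinds the name. A standalone sketch of that default-argument behavior, with registry and evict as illustrative names that are not part of the PR:

registry = []  # stands in for the module-level current_loaded_models

def evict(count, pool=registry):
    # `pool` defaults to the object `registry` pointed at when `def` ran,
    # so in-place appends/pops stay visible here, but rebinding
    # `registry = []` elsewhere would silently detach the default.
    evicted = []
    for _ in range(min(count, len(pool))):
        evicted.append(pool.pop())
    return evicted

registry.extend(["vae", "clip", "unet"])
print(evict(1))                 # pops from the shared list: ['unet']
print(evict(1, pool=["lora"]))  # pops only from the caller's list: ['lora']
print(registry)                 # ['vae', 'clip']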
@@ -634,6 +634,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         models = models_temp
 
     models_to_load = []
+    models_to_reload = []
 
     for x in models:
         loaded_model = LoadedModel(x)
@@ -645,17 +646,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if loaded_model_index is not None:
             loaded = current_loaded_models[loaded_model_index]
             loaded.currently_used = True
-            models_to_load.append(loaded)
+            models_to_reload.append(loaded)
         else:
             if hasattr(x, "model"):
                 logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
+    models_to_load += models_to_reload
+
     for loaded_model in models_to_load:
         to_unload = []
         for i in range(len(current_loaded_models)):
             if loaded_model.model.is_clone(current_loaded_models[i].model):
                 to_unload = [i] + to_unload
+                if not current_loaded_models[i] in models_to_reload:
+                    models_to_reload.append(current_loaded_models[i])
         for i in to_unload:
             model_to_unload = current_loaded_models.pop(i)
             model_to_unload.model.detach(unpatch_all=False)
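
Two details in this hunk are easy to miss. Already-resident models are now routed into models_to_reload and appended after the fresh loads (models_to_load += models_to_reload), so new models are processed first, and any detected clone in current_loaded_models is also recorded as a reload candidate. The to_unload = [i] + to_unload prepend is what keeps the later pops safe: indices come out in descending order, so removing one entry never shifts an index that is still pending. A standalone sketch of that index bookkeeping, with made-up list contents:

items = ["a", "b", "c", "d"]
to_unload = []
for i in range(len(items)):
    if items[i] in ("b", "d"):       # pretend these are detected clones
        to_unload = [i] + to_unload  # prepend -> [3, 1], descending order
for i in to_unload:
    items.pop(i)                     # pop index 3 first, then 1: both still valid
print(items)                         # ['a', 'c']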
@@ -665,16 +670,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     for loaded_model in models_to_load:
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
+    def free_memory_required(vram, device, models_to_reload):
+        if get_free_memory(device) < vram:
+            models_unloaded = free_memory(vram, device)
+            if len(models_unloaded):
+                logging.info("{} idle models unloaded.".format(len(models_unloaded)))
+
+            models_unloaded = free_memory(vram, device, loaded_models=models_to_reload)
+            if len(models_unloaded):
+                logging.info("{} active models unloaded for increased offloading.".format(len(models_unloaded)))
+                for unloaded_model in models_unloaded:
+                    if unloaded_model in models_to_load:
+                        unloaded_model.currently_used = True
+
     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
+            free_memory_required(total_memory_required[device] * 1.1 + extra_mem, device, models_to_reload)
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_mem = get_free_memory(device)
-            if free_mem < minimum_memory_required:
-                models_l = free_memory(minimum_memory_required, device)
-                logging.info("{} models unloaded.".format(len(models_l)))
+            free_memory_required(minimum_memory_required, device, models_to_reload)
 
     for loaded_model in models_to_load:
         model = loaded_model.model
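
The new free_memory_required helper encodes a two-pass eviction policy: pass one frees idle models through the default current_loaded_models scan; if the device is still short, pass two hands the about-to-be-used models themselves to free_memory via loaded_models so they can shed memory rather than have the load fail, and anything from the current batch that got unloaded is re-marked currently_used. A condensed, standalone sketch of that control flow, with invented names and sizes (each model here "frees" 4 units; the real second pass offloads weights via model_unload rather than simply dropping entries):

def ensure_vram(required, free, idle, active):
    # pass 1: evict models nothing in this batch is using
    while idle and free < required:
        idle.pop()
        free += 4  # pretend each eviction releases 4 units
    # pass 2: still short -> shrink the batch's own models; the real
    # helper does this by calling free_memory(..., loaded_models=models_to_reload)
    offloaded = 0
    while active and free < required:
        active.pop()
        free += 4
        offloaded += 1
    if offloaded:
        print("{} active models unloaded for increased offloading.".format(offloaded))
    return free

free = ensure_vram(8, free=2, idle=["old_unet"], active=["unet", "vae"])
print(free)  # 10: one idle and one active model were freed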