1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -105,6 +105,7 @@ class LatentPreviewMethod(enum.Enum):
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
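The nargs='?' / const=4.0 / default=0 combination gives the new flag three states. A minimal standalone sketch of the resulting parse behaviour (the value is free-RAM headroom in GB; omitting the flag disables the feature):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0)

    print(parser.parse_args([]).cache_ram)                    # 0   -> feature disabled
    print(parser.parse_args(["--cache-ram"]).cache_ram)       # 4.0 -> default headroom
    print(parser.parse_args(["--cache-ram", "8"]).cache_ram)  # 8.0 -> explicit headroom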
3 changes: 3 additions & 0 deletions comfy/model_patcher.py
@@ -275,6 +275,9 @@ def model_size(self):
self.size = comfy.model_management.module_size(self.model)
return self.size

def get_ram_usage(self):
return self.model_size()

def loaded_size(self):
return self.model.model_loaded_weight_memory

14 changes: 14 additions & 0 deletions comfy/sd.py
@@ -143,6 +143,9 @@ def clone(self):
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n

def get_ram_usage(self):
return self.patcher.get_ram_usage()

def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)

@@ -293,6 +296,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None

self.downscale_index_formula = None
self.upscale_index_formula = None
@@ -595,6 +599,16 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):

self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()

def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size

def get_ram_usage(self):
return self.model_size()

def throw_exception_if_invalid(self):
if self.first_stage_model is None:
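The patcher-backed wrappers in comfy/sd.py and the VAE now report get_ram_usage() in bytes (via module_size). The RAM pressure cache introduced below discovers the method by duck typing (hasattr), so any object a node returns can opt in. A hypothetical sketch of such an object, not part of this PR:

    class OffloadedEmbeddings:
        """Hypothetical cached node output that reports its own system-RAM footprint."""
        def __init__(self, payloads):
            self.payloads = payloads  # e.g. raw bytes blobs kept in CPU memory

        def get_ram_usage(self):
            # Report bytes, matching module_size() and the tensor byte counts used in scoring.
            return sum(len(p) for p in self.payloads)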
83 changes: 83 additions & 0 deletions comfy_execution/caching.py
@@ -1,4 +1,9 @@
import bisect
import gc
import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
@@ -188,6 +193,9 @@ def clean_unused(self):
self._clean_cache()
self._clean_subcaches()

def poll(self, **kwargs):
pass

def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
@@ -276,6 +284,9 @@ def all_node_ids(self):
def clean_unused(self):
pass

def poll(self, **kwargs):
pass

def get(self, node_id):
return None

@@ -336,3 +347,75 @@ async def ensure_subcache_for(self, node_id, children_ids):
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self


#Iterating the cache for usage analysis might be expensive, so if eviction triggers, make sure
#to take out a decent chunk to give breathing space on high-node / low-RAM-per-node flows.

RAM_CACHE_HYSTERESIS = 1.1

#This is roughly in GB, but not exactly. It needs to be non-zero for the heuristic below,
#and as long as multi-GB models dwarf it, the OOM scoring remains a reasonable approximation.

RAM_CACHE_DEFAULT_RAM_USAGE = 0.1

#Exponential bias towards evicting results from older workflows so stale entries get
#cleared out in constantly changing setups.

RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

class RAMPressureCache(LRUCache):

def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}

def clean_unused(self):
self._clean_subcaches()

def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)

def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)

def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)

if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return

clean_list = []

for key, (outputs, _) in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])

ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
def scan_list_for_ram_usage(outputs):
nonlocal ram_usage
for output in outputs:
if isinstance(output, list):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)

oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))

while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
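For intuition on the eviction order: each entry's score is the age bias raised to the number of generations since it was last used, multiplied by its estimated RAM footprint, and clean_list.pop() evicts from the high end first. A small worked example with made-up numbers:

    RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

    def oom_score(generations_since_use, ram_bytes):
        # Higher score == evicted sooner (clean_list is sorted ascending, popped from the end).
        return (RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** generations_since_use) * ram_bytes

    print(oom_score(0, 6e9))  # ~6.0e9: a large checkpoint touched this generation
    print(oom_score(5, 2e9))  # ~7.4e9: a smaller, stale result -- evicted first despite its size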
7 changes: 6 additions & 1 deletion comfy_execution/graph.py
@@ -212,7 +212,12 @@ def cache_link(self, from_node_id, to_node_id):
def get_output_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
return None
return self.execution_cache[to_node_id].get(from_node_id)
value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
return value[0]

def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners:
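Cached values are now (output_data, ui_output) pairs (see the execution.py changes below), so a cache hit here refreshes the entry's recency in the main cache and hands back only the outputs half. A simplified sketch of that convention, using hypothetical data:

    entry = (["<latent tensor>"], {"meta": {"node_id": "7"}, "output": {"images": []}})

    main_cache = {}

    def get_output_cache_sketch(node_id, entry):
        if entry is None:
            return None
        main_cache[node_id] = entry  # write back on touch so the RAM cache sees recent use
        return entry[0]              # downstream nodes only ever receive the output data

    print(get_output_cache_sketch("7", entry))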
63 changes: 34 additions & 29 deletions execution.py
@@ -21,6 +21,7 @@
NullCache,
HierarchicalCache,
LRUCache,
RAMPressureCache,
)
from comfy_execution.graph import (
DynamicPrompt,
@@ -92,45 +93,47 @@ class CacheType(Enum):
CLASSIC = 0
LRU = 1
NONE = 2
RAM_PRESSURE = 3


class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
def __init__(self, cache_type=None, cache_args={}):
if cache_type == CacheType.NONE:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
cache_size = cache_args.get("lru", 0)
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_classic_cache()

self.all = [self.outputs, self.ui, self.objects]
self.all = [self.outputs, self.objects]

# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)

def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)

def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)

def init_null_cache(self):
self.outputs = NullCache()
#The UI cache is expected to be iterable at the end of each workflow
#so it must cache at least a full workflow. Use Hierarchical
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache()

def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result

@@ -393,20 +396,23 @@ def format_value(x):
else:
return str(x)

async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
parent_node_id = dynprompt.get_parent_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None:
cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
cached_ui = cached[1] or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached[1] is not None:
ui_outputs[unique_id] = cached[1]
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)

input_data_all = None
@@ -506,15 +512,15 @@ async def await_completion():
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
ui_outputs[unique_id] = {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
"parent_node": parent_node_id,
"real_node_id": real_node_id,
},
"output": output_ui
})
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
@@ -557,8 +563,8 @@ async def await_completion():
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)

caches.outputs.set(unique_id, output_data)
execution_list.cache_update(unique_id, output_data)
execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id)))
caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id)))

except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -603,14 +609,14 @@ async def await_completion():
return (ExecutionResult.SUCCESS, None, None)

class PromptExecutor:
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
def __init__(self, server, cache_type=False, cache_args=None):
self.cache_args = cache_args
self.cache_type = cache_type
self.server = server
self.reset()

def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = []
self.success = True

@@ -685,6 +691,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@@ -698,7 +705,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
break

assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -707,18 +714,16 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

ui_outputs = {}
meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
4 changes: 3 additions & 1 deletion main.py
@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.NONE

e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
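Putting the pieces together, the flag-to-cache mapping works roughly like this (headroom values in GB; assumes the defaults shown above):

    # (no cache flag)  -> CacheType.CLASSIC       cache_args = {"lru": 0,  "ram": 0}
    # --cache-lru 20   -> CacheType.LRU           cache_args = {"lru": 20, "ram": 0}
    # --cache-ram      -> CacheType.RAM_PRESSURE  cache_args = {"lru": 0,  "ram": 4.0}
    # --cache-ram 8    -> CacheType.RAM_PRESSURE  cache_args = {"lru": 0,  "ram": 8.0}
    # --cache-none     -> CacheType.NONE          (only if neither LRU nor RAM caching is set)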