Commit 513b0c4

Add RAM Pressure cache mode (#10454)
* execution: Roll the UI cache into the outputs

  Currently the UI cache sits parallel to the output cache and is expected to be a content superset of it. At the same time the two caches are maintained completely separately, making it awkward to free output cache content without changing the behaviour of the UI cache.

  There are two actual users (getters) of the UI cache. The first is a direct content hit on the output cache when executing a node; this case is handled very naturally by merging the UI and output caches. The second is the history JSON generation at the end of the prompt, which currently asks the cache for all_node_ids and then pulls the cache contents for those nodes (all_node_ids being the nodes of the dynamic prompt).

  So fold the UI cache into the output cache. The current UI cache setter now writes to a prompt-scoped dict. When the output cache is set, this value is read from the dict and tupled up with the outputs. When generating the history, simply iterate the prompt-scoped dict.

  This prepares support for more complex caching strategies (like RAM pressure caching) where less than one full workflow will be cached and it is desirable to keep the UI cache and output cache in sync.

* sd: Implement RAM getter for VAE

* model_patcher: Implement RAM getter for ModelPatcher

* sd: Implement RAM getter for CLIP

* Implement RAM Pressure cache

  Implement a cache sensitive to RAM pressure. When RAM headroom drops below a certain threshold, evict RAM-expensive nodes from the cache. Models and tensors are measured directly for RAM usage, and an OOM score is then computed from the RAM usage of each node. Note that, due to indirection through shared objects (like a model patcher), multiple nodes can account the same RAM as their individual usage. The intent is that this frees chains of nodes, particularly model loaders and their associated LoRAs, since they all score similarly and sort close to each other. This biases towards unloading model nodes mid-flow while being able to keep results like text encodings and VAE outputs.

* execution: Convert the cache entry to NamedTuple

  As commented in review: convert the cache entry to a named tuple and abstract the tuple type away completely from graph.py.
1 parent dfac946 commit 513b0c4
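
As a reading aid for the first and last points in the commit message, here is a minimal sketch of the merged cache entry shape. The field names match the CacheEntry NamedTuple added to execution.py in this commit; the example values are made up:

    from typing import NamedTuple

    class CacheEntry(NamedTuple):
        ui: dict       # UI payload previously held in the separate UI cache (may be None)
        outputs: list  # node output data previously held in the output cache

    # Example: one node's outputs and its UI payload cached as a single entry.
    entry = CacheEntry(ui={"output": {"images": []}, "meta": {}}, outputs=[["IMAGE"]])
    print(entry.outputs, entry.ui["output"])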

7 files changed, +157 -38 lines changed


comfy/cli_args.py

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ class LatentPreviewMethod(enum.Enum):
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

comfy/model_patcher.py

Lines changed: 3 additions & 0 deletions
@@ -276,6 +276,9 @@ def model_size(self):
         self.size = comfy.model_management.module_size(self.model)
         return self.size

+    def get_ram_usage(self):
+        return self.model_size()
+
     def loaded_size(self):
         return self.model.model_loaded_weight_memory

comfy/sd.py

Lines changed: 14 additions & 0 deletions
@@ -143,6 +143,9 @@ def clone(self):
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
         return n

+    def get_ram_usage(self):
+        return self.patcher.get_ram_usage()
+
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)

@@ -293,6 +296,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
         self.working_dtypes = [torch.bfloat16, torch.float32]
         self.disable_offload = False
         self.not_video = False
+        self.size = None

         self.downscale_index_formula = None
         self.upscale_index_formula = None

@@ -595,6 +599,16 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):

         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
         logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+        self.model_size()
+
+    def model_size(self):
+        if self.size is not None:
+            return self.size
+        self.size = comfy.model_management.module_size(self.first_stage_model)
+        return self.size
+
+    def get_ram_usage(self):
+        return self.model_size()

     def throw_exception_if_invalid(self):
         if self.first_stage_model is None:
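
The getters above report a resident size in bytes via comfy.model_management.module_size, either directly (the VAE) or through the underlying ModelPatcher (CLIP). Any other object returned by a node could in principle join the same accounting, since the RAM pressure cache below only duck-types on a get_ram_usage attribute. A hypothetical example, not part of this commit:

    # Hypothetical custom output object: it would be picked up by the
    # hasattr(output, "get_ram_usage") check in RAMPressureCache.poll().
    class BigLookupTable:
        def __init__(self, blob: bytes):
            self.blob = blob

        def get_ram_usage(self):
            # Report resident size in bytes, like ModelPatcher/VAE/CLIP do.
            return len(self.blob)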

comfy_execution/caching.py

Lines changed: 83 additions & 0 deletions
@@ -1,4 +1,9 @@
+import bisect
+import gc
 import itertools
+import psutil
+import time
+import torch
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
 from abc import ABC, abstractmethod

@@ -188,6 +193,9 @@ def clean_unused(self):
         self._clean_cache()
         self._clean_subcaches()

+    def poll(self, **kwargs):
+        pass
+
     def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)

@@ -276,6 +284,9 @@ def all_node_ids(self):
     def clean_unused(self):
         pass

+    def poll(self, **kwargs):
+        pass
+
     def get(self, node_id):
         return None

@@ -336,3 +347,75 @@ async def ensure_subcache_for(self, node_id, children_ids):
             self._mark_used(child_id)
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
+
+
+#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
+#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
+RAM_CACHE_HYSTERESIS = 1.1
+
+#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
+#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
+RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
+
+#Exponential bias towards evicting older workflows so garbage will be taken out
+#in constantly changing setups.
+RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
+
+class RAMPressureCache(LRUCache):
+
+    def __init__(self, key_class):
+        super().__init__(key_class, 0)
+        self.timestamps = {}
+
+    def clean_unused(self):
+        self._clean_subcaches()
+
+    def set(self, node_id, value):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        super().set(node_id, value)
+
+    def get(self, node_id):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        return super().get(node_id)
+
+    def poll(self, ram_headroom):
+        def _ram_gb():
+            return psutil.virtual_memory().available / (1024**3)
+
+        if _ram_gb() > ram_headroom:
+            return
+        gc.collect()
+        if _ram_gb() > ram_headroom:
+            return
+
+        clean_list = []
+
+        for key, (outputs, _), in self.cache.items():
+            oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
+
+            ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
+            def scan_list_for_ram_usage(outputs):
+                nonlocal ram_usage
+                for output in outputs:
+                    if isinstance(output, list):
+                        scan_list_for_ram_usage(output)
+                    elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
+                        #score Tensors at a 50% discount for RAM usage as they are likely to
+                        #be high value intermediates
+                        ram_usage += (output.numel() * output.element_size()) * 0.5
+                    elif hasattr(output, "get_ram_usage"):
+                        ram_usage += output.get_ram_usage()
+            scan_list_for_ram_usage(outputs)
+
+            oom_score *= ram_usage
+            #In the case where we have no information on the node ram usage at all,
+            #break OOM score ties on the last touch timestamp (pure LRU)
+            bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
+
+        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
+            _, _, key = clean_list.pop()
+            del self.cache[key]
+        gc.collect()
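
To make the eviction order concrete, here is a small self-contained sketch with illustrative byte counts (not measurements): bisect.insort keeps clean_list sorted ascending by OOM score, and pop() takes from the end, so the most RAM-expensive, least recently useful entries are dropped first.

    import bisect
    import time

    RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

    def oom_score(ram_usage, generations_since_use):
        # Same shape as the score computed in RAMPressureCache.poll().
        return (RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** generations_since_use) * ram_usage

    clean_list = []
    now = time.time()
    # (score, last-touch timestamp, key)
    bisect.insort(clean_list, (oom_score(6e9, 2), now - 60, "checkpoint_loader"))  # ~6 GB model, two generations stale
    bisect.insort(clean_list, (oom_score(5e8, 0), now - 5, "vae_output"))          # ~0.5 GB tensor, just used
    bisect.insort(clean_list, (oom_score(2e7, 0), now - 1, "clip_encode"))         # small text embedding, just used

    while clean_list:
        print(clean_list.pop()[2])  # checkpoint_loader, then vae_output, then clip_encode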

comfy_execution/graph.py

Lines changed: 7 additions & 2 deletions
@@ -209,10 +209,15 @@ def cache_link(self, from_node_id, to_node_id):
             self.execution_cache_listeners[from_node_id] = set()
         self.execution_cache_listeners[from_node_id].add(to_node_id)

-    def get_output_cache(self, from_node_id, to_node_id):
+    def get_cache(self, from_node_id, to_node_id):
         if not to_node_id in self.execution_cache:
             return None
-        return self.execution_cache[to_node_id].get(from_node_id)
+        value = self.execution_cache[to_node_id].get(from_node_id)
+        if value is None:
+            return None
+        #Write back to the main cache on touch.
+        self.output_cache.set(from_node_id, value)
+        return value

     def cache_update(self, node_id, value):
         if node_id in self.execution_cache_listeners:
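
The write-back matters for the RAM pressure mode: re-setting the value into the main output cache calls RAMPressureCache.set(), which appears to refresh the entry's last-touch timestamp and its LRU bookkeeping, so results still being read by downstream nodes score lower at the next poll(). A rough illustration of how the exponential age term alone changes the score (illustrative numbers only, constant taken from comfy_execution/caching.py above):

    RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

    ram = 4e9  # ~4 GB model shared by several nodes
    for age in (0, 1, 3, 6):  # workflow generations since the entry was last touched
        print(age, RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** age * ram)
    # age 0 -> 4.0e9, age 3 -> ~8.8e9, age 6 -> ~19.3e9: stale entries are evicted first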

execution.py

Lines changed: 46 additions & 35 deletions
@@ -21,6 +21,7 @@
     NullCache,
     HierarchicalCache,
     LRUCache,
+    RAMPressureCache,
 )
 from comfy_execution.graph import (
     DynamicPrompt,

@@ -88,49 +89,56 @@ async def get(self, node_id):
         return self.is_changed[node_id]


+class CacheEntry(NamedTuple):
+    ui: dict
+    outputs: list
+
+
 class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
     NONE = 2
+    RAM_PRESSURE = 3


 class CacheSet:
-    def __init__(self, cache_type=None, cache_size=None):
+    def __init__(self, cache_type=None, cache_args={}):
         if cache_type == CacheType.NONE:
             self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
+        elif cache_type == CacheType.RAM_PRESSURE:
+            cache_ram = cache_args.get("ram", 16.0)
+            self.init_ram_cache(cache_ram)
+            logging.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
-            if cache_size is None:
-                cache_size = 0
+            cache_size = cache_args.get("lru", 0)
             self.init_lru_cache(cache_size)
             logging.info("Using LRU cache")
         else:
             self.init_classic_cache()

-        self.all = [self.outputs, self.ui, self.objects]
+        self.all = [self.outputs, self.objects]

     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_lru_cache(self, cache_size):
         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
+    def init_ram_cache(self, min_headroom):
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_null_cache(self):
         self.outputs = NullCache()
-        #The UI cache is expected to be iterable at the end of each workflow
-        #so it must cache at least a full workflow. Use Heirachical
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = NullCache()

     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),
-            "ui": self.ui.recursive_debug_dump(),
         }
         return result

@@ -157,14 +165,14 @@ def mark_missing():
             if execution_list is None:
                 mark_missing()
                 continue # This might be a lazily-evaluated input
-            cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
-            if cached_output is None:
+            cached = execution_list.get_cache(input_unique_id, unique_id)
+            if cached is None or cached.outputs is None:
                 mark_missing()
                 continue
-            if output_index >= len(cached_output):
+            if output_index >= len(cached.outputs):
                 mark_missing()
                 continue
-            obj = cached_output[output_index]
+            obj = cached.outputs[output_index]
             input_data_all[x] = obj
         elif input_category is not None:
             input_data_all[x] = [input_data]

@@ -393,20 +401,23 @@ def format_value(x):
     else:
         return str(x)

-async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
+async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
     parent_node_id = dynprompt.get_parent_node_id(unique_id)
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    if caches.outputs.get(unique_id) is not None:
+    cached = caches.outputs.get(unique_id)
+    if cached is not None:
         if server.client_id is not None:
-            cached_output = caches.ui.get(unique_id) or {}
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
+            cached_ui = cached.ui or {}
+            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
+        if cached.ui is not None:
+            ui_outputs[unique_id] = cached.ui
         get_progress_state().finish_progress(unique_id)
-        execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
+        execution_list.cache_update(unique_id, cached)
        return (ExecutionResult.SUCCESS, None, None)

    input_data_all = None

@@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 for r in result:
                     if is_link(r):
                         source_node, source_output = r[0], r[1]
-                        node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
-                        for o in node_output:
+                        node_cached = execution_list.get_cache(source_node, unique_id)
+                        for o in node_cached.outputs[source_output]:
                             resolved_output.append(o)

                     else:

@@ -507,15 +518,15 @@ async def await_completion():
             asyncio.create_task(await_completion())
             return (ExecutionResult.PENDING, None, None)
         if len(output_ui) > 0:
-            caches.ui.set(unique_id, {
+            ui_outputs[unique_id] = {
                 "meta": {
                     "node_id": unique_id,
                     "display_node": display_node_id,
                     "parent_node": parent_node_id,
                     "real_node_id": real_node_id,
                 },
                 "output": output_ui
-            })
+            }
         if server.client_id is not None:
             server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
         if has_subgraph:

@@ -554,8 +565,9 @@
             pending_subgraph_results[unique_id] = cached_outputs
             return (ExecutionResult.PENDING, None, None)

-        caches.outputs.set(unique_id, output_data)
-        execution_list.cache_update(unique_id, output_data)
+        cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
+        execution_list.cache_update(unique_id, cache_entry)
+        caches.outputs.set(unique_id, cache_entry)

     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")

@@ -600,14 +612,14 @@ async def await_completion():
         return (ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
-    def __init__(self, server, cache_type=False, cache_size=None):
-        self.cache_size = cache_size
+    def __init__(self, server, cache_type=False, cache_args=None):
+        self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()

     def reset(self):
-        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
         self.status_messages = []
         self.success = True

@@ -682,6 +694,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
                          broadcast=False)
         pending_subgraph_results = {}
         pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+        ui_node_outputs = {}
         executed = set()
         execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()

@@ -695,7 +708,7 @@
                     break

                 assert node_id is not None, "Node ID should not be None at this point"
-                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
+                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
                 self.success = result != ExecutionResult.FAILURE
                 if result == ExecutionResult.FAILURE:
                     self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)

@@ -704,18 +717,16 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
                     execution_list.unstage_node_execution()
                 else: # result == ExecutionResult.SUCCESS:
                     execution_list.complete_node_execution()
+                self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
             else:
                 # Only execute when the while-loop ends without break
                 self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

             ui_outputs = {}
             meta_outputs = {}
-            all_node_ids = self.caches.ui.all_node_ids()
-            for node_id in all_node_ids:
-                ui_info = self.caches.ui.get(node_id)
-                if ui_info is not None:
-                    ui_outputs[node_id] = ui_info["output"]
-                    meta_outputs[node_id] = ui_info["meta"]
+            for node_id, ui_info in ui_node_outputs.items():
+                ui_outputs[node_id] = ui_info["output"]
+                meta_outputs[node_id] = ui_info["meta"]
             self.history_result = {
                 "outputs": ui_outputs,
                 "meta": meta_outputs,
