From 9e0cc78024c14c4e8cf1f8b4eeddf26336533a5a Mon Sep 17 00:00:00 2001
From: Rattus
Date: Tue, 21 Oct 2025 17:10:16 +1000
Subject: [PATCH 1/6] execution: Roll the UI cache into the outputs

Currently the UI cache sits parallel to the output cache and is
expected to be a content superset of it. At the same time the two
caches are maintained completely separately, making it awkward to free
output cache content without changing the behaviour of the UI cache.

There are two actual users (getters) of the UI cache. The first is the
case of a direct content hit on the output cache when executing a node.
This case is handled very naturally by merging the UI and output
caches.

The second case is the history JSON generation at the end of the
prompt. This currently works by asking the cache for all_node_ids and
then pulling the cache contents for those nodes. all_node_ids covers
the nodes of the dynamic prompt.

So fold the UI cache into the output cache. The former UI cache setter
now writes to a prompt-scoped dict. When the output cache is set, fetch
this value from the dict and store it in a tuple together with the
outputs. When generating the history, simply iterate the prompt-scoped
dict.

This prepares support for more complex caching strategies (like RAM
pressure caching) where less than one full workflow may be cached and
it will be desirable to keep the UI and output caches in sync.
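
For illustration, a cache entry and the history iteration end up shaped
roughly like this (sketch only; names follow the diff below, with
history_outputs standing in for the local ui_outputs dict):

    # One output-cache entry now carries the UI payload with the outputs:
    # (output_data, ui), where ui is the {"meta": ..., "output": ...} dict
    # from the prompt-scoped ui_outputs dict, or None.
    entry = (output_data, ui_outputs.get(unique_id))

    # History generation iterates the prompt-scoped dict directly:
    for node_id, ui_info in ui_node_outputs.items():
        history_outputs[node_id] = ui_info["output"]
        meta_outputs[node_id] = ui_info["meta"]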
---
 comfy_execution/graph.py |  7 ++++++-
 execution.py             | 41 ++++++++++++++++++----------------------
 2 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 341c9735d571..aadb562e6218 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -212,7 +212,12 @@ def cache_link(self, from_node_id, to_node_id):
     def get_output_cache(self, from_node_id, to_node_id):
         if not to_node_id in self.execution_cache:
             return None
-        return self.execution_cache[to_node_id].get(from_node_id)
+        value = self.execution_cache[to_node_id].get(from_node_id)
+        if value is None:
+            return None
+        #Write back to the main cache on touch.
+        self.output_cache.set(from_node_id, value)
+        return value[0]
 
     def cache_update(self, node_id, value):
         if node_id in self.execution_cache_listeners:

diff --git a/execution.py b/execution.py
index 78c36a4b0556..cfd22127269e 100644
--- a/execution.py
+++ b/execution.py
@@ -107,30 +107,24 @@ def __init__(self, cache_type=None, cache_size=None):
         else:
             self.init_classic_cache()
 
-        self.all = [self.outputs, self.ui, self.objects]
+        self.all = [self.outputs, self.objects]
 
     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_lru_cache(self, cache_size):
         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_null_cache(self):
         self.outputs = NullCache()
-        #The UI cache is expected to be iterable at the end of each workflow
-        #so it must cache at least a full workflow. Use Heirachical
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = NullCache()
 
     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),
-            "ui": self.ui.recursive_debug_dump(),
         }
         return result
 
@@ -393,7 +387,7 @@ def format_value(x):
     else:
         return str(x)
 
-async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
+async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -401,12 +395,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    if caches.outputs.get(unique_id) is not None:
+    cached = caches.outputs.get(unique_id)
+    if cached is not None:
         if server.client_id is not None:
-            cached_output = caches.ui.get(unique_id) or {}
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
+            cached_ui = cached[1] or {}
+            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
+        if cached[1] is not None:
+            ui_outputs[unique_id] = cached[1]
         get_progress_state().finish_progress(unique_id)
-        execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
+        execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)
 
     input_data_all = None
@@ -506,7 +503,7 @@ async def await_completion():
                 asyncio.create_task(await_completion())
                 return (ExecutionResult.PENDING, None, None)
         if len(output_ui) > 0:
-            caches.ui.set(unique_id, {
+            ui_outputs[unique_id] = {
                 "meta": {
                     "node_id": unique_id,
                     "display_node": display_node_id,
@@ -514,7 +511,7 @@ async def await_completion():
                     "real_node_id": real_node_id,
                 },
                 "output": output_ui
-            })
+            }
             if server.client_id is not None:
                 server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
         if has_subgraph:
@@ -557,8 +554,8 @@ async def await_completion():
                 pending_subgraph_results[unique_id] = cached_outputs
                 return (ExecutionResult.PENDING, None, None)
 
-        caches.outputs.set(unique_id, output_data)
-        execution_list.cache_update(unique_id, output_data)
+        execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id)))
+        caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id)))
 
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
@@ -685,6 +682,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
                           broadcast=False)
         pending_subgraph_results = {}
         pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+        ui_node_outputs = {}
         executed = set()
         execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()
@@ -698,7 +696,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
                     break
 
                 assert node_id is not None, "Node ID should not be None at this point"
-                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
+                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
                 self.success = result != ExecutionResult.FAILURE
                 if result == ExecutionResult.FAILURE:
                     self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -713,12 +711,9 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= ui_outputs = {} meta_outputs = {} - all_node_ids = self.caches.ui.all_node_ids() - for node_id in all_node_ids: - ui_info = self.caches.ui.get(node_id) - if ui_info is not None: - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] self.history_result = { "outputs": ui_outputs, "meta": meta_outputs, From 708f002955c82d46b3d0438e5eab1997b90e1d17 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 21 Oct 2025 18:03:33 +1000 Subject: [PATCH 2/6] sd: Implement RAM getter for VAE --- comfy/sd.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index 28bee248dae1..ad2f3bd449a4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -293,6 +293,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False self.not_video = False + self.size = None self.downscale_index_formula = None self.upscale_index_formula = None @@ -595,6 +596,16 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + self.model_size() + + def model_size(self): + if self.size is not None: + return self.size + self.size = comfy.model_management.module_size(self.first_stage_model) + return self.size + + def get_ram_usage(self): + return self.model_size() def throw_exception_if_invalid(self): if self.first_stage_model is None: From d34fb4d76f1ea9cec71b51bddcc89e68a0da4636 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 21 Oct 2025 18:03:33 +1000 Subject: [PATCH 3/6] model_patcher: Implement RAM getter for ModelPatcher --- comfy/model_patcher.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c0b68fb8cff7..deb4aa4c2fc7 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -275,6 +275,9 @@ def model_size(self): self.size = comfy.model_management.module_size(self.model) return self.size + def get_ram_usage(self): + return self.model_size() + def loaded_size(self): return self.model.model_loaded_weight_memory From 0c95f229073c357d6cacd803b1e5980369485173 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 23 Oct 2025 21:24:29 +1000 Subject: [PATCH 4/6] sd: Implement RAM getter for CLIP --- comfy/sd.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index ad2f3bd449a4..b7e1c77a16b7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -143,6 +143,9 @@ def clone(self): n.apply_hooks_to_conds = 
---
 comfy/sd.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/comfy/sd.py b/comfy/sd.py
index ad2f3bd449a4..b7e1c77a16b7 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -143,6 +143,9 @@ def clone(self):
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
         return n
 
+    def get_ram_usage(self):
+        return self.patcher.get_ram_usage()
+
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)

From f3f526fcd3618f19a642aab78c958e81b6c0dd3e Mon Sep 17 00:00:00 2001
From: Rattus
Date: Thu, 23 Oct 2025 23:52:37 +1000
Subject: [PATCH 5/6] Implement RAM Pressure cache

Implement a cache sensitive to RAM pressure. When RAM headroom drops
below a certain threshold, evict RAM-expensive nodes from the cache.

Models and tensors are measured directly for RAM usage. An OOM score is
then computed based on the RAM usage of the node. Note that due to
indirection through shared objects (like a model patcher), multiple
nodes can account for the same RAM as their individual usage. The
intent is that this will free chains of nodes, particularly model
loaders and their associated loras, as they all score similarly and
sort close to each other.

This has a bias towards unloading model nodes mid-flow while being able
to keep results like text encodings and VAE outputs.
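
As a rough worked example of the scoring (sketch; age_in_generations
and ram_usage are illustrative names, constants are from the diff
below): a 4GB checkpoint output untouched for two generations scores
about 4.3e9 * 1.3^2 ~= 7.3e9, while a freshly touched 16MB CPU latent
scores about 8.4e6 after the 50% tensor discount, so the model chain is
evicted long before small intermediate results:

    # Per-entry score in RAMPressureCache.poll(); highest score is evicted first.
    oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** age_in_generations
    oom_score *= ram_usage  # bytes: get_ram_usage() for models, 50% for tensors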
---
 comfy/cli_args.py          |  1 +
 comfy_execution/caching.py | 83 ++++++++++++++++++++++++++++++++++++++
 execution.py               | 22 +++++++---
 main.py                    |  4 +-
 4 files changed, 103 insertions(+), 7 deletions(-)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index cc1f12482e9f..74c8175f947f 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -105,6 +105,7 @@ class LatentPreviewMethod(enum.Enum):
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold, the cache removes large items to free RAM. Default 4GB.")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 566bc3f9c74a..b498f43e7cd7 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -1,4 +1,9 @@
+import bisect
+import gc
 import itertools
+import psutil
+import time
+import torch
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
 from abc import ABC, abstractmethod
@@ -188,6 +193,9 @@ def clean_unused(self):
         self._clean_cache()
         self._clean_subcaches()
 
+    def poll(self, **kwargs):
+        pass
+
     def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
@@ -276,6 +284,9 @@ def all_node_ids(self):
     def clean_unused(self):
         pass
 
+    def poll(self, **kwargs):
+        pass
+
     def get(self, node_id):
         return None
 
@@ -336,3 +347,75 @@ async def ensure_subcache_for(self, node_id, children_ids):
             self._mark_used(child_id)
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
+
+
+#Iterating the cache for usage analysis might be expensive, so if we do trigger, make
+#sure to take a chunk out to give breathing space on high-node / low-RAM-per-node flows.
+
+RAM_CACHE_HYSTERESIS = 1.1
+
+#This is roughly in GB but not exactly. It needs to be non-zero for the heuristic
+#below, and as long as multi-GB models dwarf it, OOM scoring is approximated OK.
+
+RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
+
+#Exponential bias towards evicting older workflows so garbage will be taken out
+#in constantly changing setups.
+
+RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
+
+class RAMPressureCache(LRUCache):
+
+    def __init__(self, key_class):
+        super().__init__(key_class, 0)
+        self.timestamps = {}
+
+    def clean_unused(self):
+        self._clean_subcaches()
+
+    def set(self, node_id, value):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        super().set(node_id, value)
+
+    def get(self, node_id):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        return super().get(node_id)
+
+    def poll(self, ram_headroom):
+        def _ram_gb():
+            return psutil.virtual_memory().available / (1024**3)
+
+        if _ram_gb() > ram_headroom:
+            return
+        gc.collect()
+        if _ram_gb() > ram_headroom:
+            return
+
+        clean_list = []
+
+        for key, (outputs, _) in self.cache.items():
+            oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
+
+            ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
+            def scan_list_for_ram_usage(outputs):
+                nonlocal ram_usage
+                for output in outputs:
+                    if isinstance(output, list):
+                        scan_list_for_ram_usage(output)
+                    elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
+                        #score Tensors at a 50% discount for RAM usage as they are likely
+                        #to be high-value intermediates
+                        ram_usage += (output.numel() * output.element_size()) * 0.5
+                    elif hasattr(output, "get_ram_usage"):
+                        ram_usage += output.get_ram_usage()
+            scan_list_for_ram_usage(outputs)
+
+            oom_score *= ram_usage
+            #Where we have no information on the node RAM usage at all, break OOM
+            #score ties on the last touch timestamp, oldest evicted first (pure LRU).
+            bisect.insort(clean_list, (oom_score, -self.timestamps[key], key))
+
+        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
+            _, _, key = clean_list.pop()
+            del self.cache[key]
+            gc.collect()
diff --git a/execution.py b/execution.py
index cfd22127269e..5b522fcc90e0 100644
--- a/execution.py
+++ b/execution.py
@@ -21,6 +21,7 @@
     NullCache,
     HierarchicalCache,
     LRUCache,
+    RAMPressureCache,
 )
 from comfy_execution.graph import (
     DynamicPrompt,
@@ -92,16 +93,20 @@ class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
     NONE = 2
+    RAM_PRESSURE = 3
 
 
 class CacheSet:
-    def __init__(self, cache_type=None, cache_size=None):
+    def __init__(self, cache_type=None, cache_args={}):
         if cache_type == CacheType.NONE:
             self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
+        elif cache_type == CacheType.RAM_PRESSURE:
+            cache_ram = cache_args.get("ram", 16.0)
+            self.init_ram_cache(cache_ram)
+            logging.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
-            if cache_size is None:
-                cache_size = 0
+            cache_size = cache_args.get("lru", 0)
             self.init_lru_cache(cache_size)
             logging.info("Using LRU cache")
         else:
@@ -118,6 +123,10 @@ def init_lru_cache(self, cache_size):
         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)
 
+    def init_ram_cache(self, min_headroom):
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
     def init_null_cache(self):
         self.outputs = NullCache()
         self.objects = NullCache()
@@ -600,14 +609,14 @@ async def await_completion():
         return (ExecutionResult.SUCCESS, None, None)
 
 class PromptExecutor:
-    def __init__(self, server, cache_type=False, cache_size=None):
-        self.cache_size = cache_size
+    def __init__(self, server, cache_type=False, cache_args=None):
+        self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()
 
     def reset(self):
-        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
         self.status_messages = []
         self.success = True
 
@@ -705,6 +714,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
                     execution_list.unstage_node_execution()
                 else: # result == ExecutionResult.SUCCESS:
                     execution_list.complete_node_execution()
+                    self.caches.outputs.poll(ram_headroom=(self.cache_args or {}).get("ram", 0))
         else:
             # Only execute when the while-loop ends without break
             self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

diff --git a/main.py b/main.py
index 4b4c5dcc4294..c9e5f72ed25c 100644
--- a/main.py
+++ b/main.py
@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
     cache_type = execution.CacheType.CLASSIC
     if args.cache_lru > 0:
         cache_type = execution.CacheType.LRU
+    elif args.cache_ram > 0:
+        cache_type = execution.CacheType.RAM_PRESSURE
     elif args.cache_none:
         cache_type = execution.CacheType.NONE
 
-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={"lru": args.cache_lru, "ram": args.cache_ram})
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0

From d2b87e674a6445ab4efd764b7abc820c033df913 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Thu, 30 Oct 2025 20:40:53 +1000
Subject: [PATCH 6/6] execution: Convert the cache entry to NamedTuple

As commented in review, convert the cache entry to a NamedTuple and
abstract the tuple type away from graph.py completely. Keep outputs as
the first field so the positional (outputs, _) unpack in
RAMPressureCache.poll() continues to work.
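
The resulting access pattern (sketch only; mirrors the diff below):

    # entry.outputs replaces entry[0], entry.ui replaces entry[1]
    cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
    send_payload = cache_entry.ui or {}
    first_output = cache_entry.outputs[0]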
---
 comfy_execution/graph.py |  4 ++--
 execution.py             | 28 +++++++++++++++++-----------
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index aadb562e6218..0d811e3546ef 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -209,7 +209,7 @@ def cache_link(self, from_node_id, to_node_id):
             self.execution_cache_listeners[from_node_id] = set()
         self.execution_cache_listeners[from_node_id].add(to_node_id)
 
-    def get_output_cache(self, from_node_id, to_node_id):
+    def get_cache(self, from_node_id, to_node_id):
         if not to_node_id in self.execution_cache:
             return None
         value = self.execution_cache[to_node_id].get(from_node_id)
@@ -217,7 +217,7 @@ def get_output_cache(self, from_node_id, to_node_id):
         if value is None:
             return None
         #Write back to the main cache on touch.
         self.output_cache.set(from_node_id, value)
-        return value[0]
+        return value
 
     def cache_update(self, node_id, value):
         if node_id in self.execution_cache_listeners:

diff --git a/execution.py b/execution.py
index 5b522fcc90e0..8e099c336849 100644
--- a/execution.py
+++ b/execution.py
@@ -89,6 +89,11 @@ async def get(self, node_id):
         return self.is_changed[node_id]
 
 
+class CacheEntry(NamedTuple):
+    outputs: list
+    ui: dict
+
+
 class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
@@ -160,14 +165,14 @@ def mark_missing():
             if execution_list is None:
                 mark_missing()
                 continue # This might be a lazily-evaluated input
-            cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
-            if cached_output is None:
+            cached = execution_list.get_cache(input_unique_id, unique_id)
+            if cached is None or cached.outputs is None:
                 mark_missing()
                 continue
-            if output_index >= len(cached_output):
+            if output_index >= len(cached.outputs):
                 mark_missing()
                 continue
-            obj = cached_output[output_index]
+            obj = cached.outputs[output_index]
             input_data_all[x] = obj
         elif input_category is not None:
             input_data_all[x] = [input_data]
@@ -407,10 +412,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     cached = caches.outputs.get(unique_id)
     if cached is not None:
         if server.client_id is not None:
-            cached_ui = cached[1] or {}
+            cached_ui = cached.ui or {}
             server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
-        if cached[1] is not None:
-            ui_outputs[unique_id] = cached[1]
+        if cached.ui is not None:
+            ui_outputs[unique_id] = cached.ui
         get_progress_state().finish_progress(unique_id)
         execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)
@@ -442,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                     for r in result:
                         if is_link(r):
                             source_node, source_output = r[0], r[1]
-                            node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
-                            for o in node_output:
+                            node_cached = execution_list.get_cache(source_node, unique_id)
+                            for o in node_cached.outputs[source_output]:
                                 resolved_output.append(o)
 
                         else:
@@ -563,8 +568,9 @@ async def await_completion():
                 pending_subgraph_results[unique_id] = cached_outputs
                 return (ExecutionResult.PENDING, None, None)
 
-        execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id)))
-        caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id)))
+        cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
+        execution_list.cache_update(unique_id, cache_entry)
+        caches.outputs.set(unique_id, cache_entry)
 
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")