174 changes: 20 additions & 154 deletions comfy_execution/caching.py
@@ -265,6 +265,26 @@ async def ensure_subcache_for(self, node_id, children_ids):
assert cache is not None
return await cache._ensure_subcache(node_id, children_ids)

class NullCache:

async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
pass

def all_node_ids(self):
return []

def clean_unused(self):
pass

def get(self, node_id):
return None

def set(self, node_id, value):
pass

async def ensure_subcache_for(self, node_id, children_ids):
return self

class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
@@ -316,157 +336,3 @@ async def ensure_subcache_for(self, node_id, children_ids):
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self


class DependencyAwareCache(BasicCache):
"""
A cache implementation that tracks dependencies between nodes and manages
their execution and caching accordingly. It extends the BasicCache class.
Nodes are removed from this cache once all of their descendants have been
executed.
"""

def __init__(self, key_class):
"""
Initialize the DependencyAwareCache.

Args:
key_class: The class used for generating cache keys.
"""
super().__init__(key_class)
self.descendants = {} # Maps node_id -> set of descendant node_ids
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed

async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.

Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to initialize the cache for.
is_changed_cache: Flag indicating if the cache has changed.
"""
# Clear all existing cache data
self.cache.clear()
self.subcaches.clear()
self.descendants.clear()
self.ancestors.clear()
self.executed_nodes.clear()

# Call the parent method to initialize the cache with the new prompt
await super().set_prompt(dynprompt, node_ids, is_changed_cache)

# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)

def _build_dependency_graph(self, dynprompt, node_ids):
"""
Build the dependency graph for all nodes.

Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to build the graph for.
"""
self.descendants.clear()
self.ancestors.clear()
for node_id in node_ids:
self.descendants[node_id] = set()
self.ancestors[node_id] = set()

for node_id in node_ids:
inputs = dynprompt.get_node(node_id)["inputs"]
for input_data in inputs.values():
if is_link(input_data): # Check if the input is a link to another node
ancestor_id = input_data[0]
self.descendants[ancestor_id].add(node_id)
self.ancestors[node_id].add(ancestor_id)

def set(self, node_id, value):
"""
Mark a node as executed and store its value in the cache.

Args:
node_id: The ID of the node to store.
value: The value to store for the node.
"""
self._set_immediate(node_id, value)
self.executed_nodes.add(node_id)
self._cleanup_ancestors(node_id)

def get(self, node_id):
"""
Retrieve the cached value for a node.

Args:
node_id: The ID of the node to retrieve.

Returns:
The cached value for the node.
"""
return self._get_immediate(node_id)

async def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.

Args:
node_id: The ID of the parent node.
children_ids: List of child node IDs to associate with the parent node.

Returns:
The subcache object for the node.
"""
subcache = await super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)
return subcache

def _cleanup_ancestors(self, node_id):
"""
Check if ancestors of a node can be removed from the cache.

Args:
node_id: The ID of the node whose ancestors are to be checked.
"""
for ancestor_id in self.ancestors.get(node_id, []):
if ancestor_id in self.executed_nodes:
# Remove ancestor if all its descendants have been executed
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
self._remove_node(ancestor_id)

def _remove_node(self, node_id):
"""
Remove a node from the cache.

Args:
node_id: The ID of the node to remove.
"""
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
del self.cache[cache_key]
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
del self.subcaches[subcache_key]

def clean_unused(self):
"""
Clean up unused nodes. This is a no-op for this cache implementation.
"""
pass

def recursive_debug_dump(self):
"""
Dump the cache and dependency graph for debugging.

Returns:
A list containing the cache state and dependency graph.
"""
result = super().recursive_debug_dump()
result.append({
"descendants": self.descendants,
"ancestors": self.ancestors,
"executed_nodes": list(self.executed_nodes),
})
return result
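The NullCache added above satisfies the same interface as the other caches while storing nothing: set() discards the value, get() always misses, and ensure_subcache_for() returns the cache itself so subgraph expansion keeps working. A minimal sketch of that contract, assuming NullCache is importable from comfy_execution.caching once this change lands; the node IDs and value below are hypothetical:

import asyncio
from comfy_execution.caching import NullCache  # class added above

async def demo():
    cache = NullCache()
    # set_prompt is a no-op; the null cache keys nothing to the prompt.
    await cache.set_prompt(dynprompt=None, node_ids=["1", "2"], is_changed_cache=None)
    cache.set("1", [["hypothetical output"]])   # value is discarded
    assert cache.get("1") is None               # every lookup misses
    assert cache.all_node_ids() == []           # nothing is ever tracked
    # Subgraph expansion still works because the null cache hands back itself.
    assert await cache.ensure_subcache_for("1", ["1.0", "1.1"]) is cache

asyncio.run(demo())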
31 changes: 29 additions & 2 deletions comfy_execution/graph.py
@@ -153,8 +153,9 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
continue
_, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
if (include_lazy or not is_lazy):
if not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))

for link in links:
@@ -194,10 +195,34 @@ def __init__(self, dynprompt, output_cache):
super().__init__(dynprompt)
self.output_cache = output_cache
self.staged_node_id = None
self.execution_cache = {}
self.execution_cache_listeners = {}

def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None

def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {}
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)

def get_output_cache(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache:
return None
return self.execution_cache[to_node_id].get(from_node_id)

def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners:
for to_node_id in self.execution_cache_listeners[node_id]:
self.execution_cache[to_node_id][node_id] = value

def add_strong_link(self, from_node_id, from_socket, to_node_id):
super().add_strong_link(from_node_id, from_socket, to_node_id)
self.cache_link(from_node_id, to_node_id)

async def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
@@ -277,6 +302,8 @@ def unstage_node_execution(self):
def complete_node_execution(self):
node_id = self.staged_node_id
self.pop_node(node_id)
self.execution_cache.pop(node_id, None)
self.execution_cache_listeners.pop(node_id, None)
self.staged_node_id = None

def get_nodes_in_cycle(self):
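With the output cache potentially a NullCache, the ExecutionList changes above keep a node's inputs alive only as long as a consumer needs them: cache_link() snapshots whatever the output cache currently holds for a (producer, consumer) pair and registers the consumer as a listener, cache_update() pushes a freshly produced value to every listener, get_output_cache() is what consumers read from, and complete_node_execution() drops a node's entries once it has run. A toy illustration of that data flow, using hypothetical node IDs rather than the real ExecutionList:

execution_cache = {}            # to_node_id -> {from_node_id: output}
execution_cache_listeners = {}  # from_node_id -> set of to_node_ids waiting on it

def cache_link(from_node_id, to_node_id, current_output):
    # Snapshot whatever the output cache holds right now (None under NullCache)
    # and remember that to_node_id wants updates from from_node_id.
    execution_cache.setdefault(to_node_id, {})[from_node_id] = current_output
    execution_cache_listeners.setdefault(from_node_id, set()).add(to_node_id)

def cache_update(node_id, value):
    # After node_id executes, hand its fresh output to every registered consumer.
    for to_node_id in execution_cache_listeners.get(node_id, ()):
        execution_cache[to_node_id][node_id] = value

def get_output_cache(from_node_id, to_node_id):
    return execution_cache.get(to_node_id, {}).get(from_node_id)

cache_link("A", "B", None)        # link registered before "A" has produced anything
cache_update("A", [["image"]])    # "A" finishes; "B" can now read its output
assert get_output_cache("A", "B") == [["image"]]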
34 changes: 20 additions & 14 deletions execution.py
@@ -18,7 +18,7 @@
BasicCache,
CacheKeySetID,
CacheKeySetInputSignature,
DependencyAwareCache,
NullCache,
HierarchicalCache,
LRUCache,
)
@@ -91,13 +91,13 @@ async def get(self, node_id):
class CacheType(Enum):
CLASSIC = 0
LRU = 1
DEPENDENCY_AWARE = 2
NONE = 2


class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
if cache_type == CacheType.NONE:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
@@ -120,11 +120,12 @@ def init_lru_cache(self, cache_size):
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)

# only hold cached items while the descendants have not executed
def init_dependency_aware_cache(self):
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
self.objects = DependencyAwareCache(CacheKeySetID)
def init_null_cache(self):
self.outputs = NullCache()
# The UI cache is expected to be iterable at the end of each workflow,
# so it must cache at least a full workflow. Use HierarchicalCache.
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache()

def recursive_debug_dump(self):
result = {
@@ -135,7 +136,7 @@ def recursive_debug_dump(self):

SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")

def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal)
if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
@@ -153,10 +154,10 @@ def mark_missing():
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0]
output_index = input_data[1]
if outputs is None:
if execution_list is None:
mark_missing()
continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id)
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
if cached_output is None:
mark_missing()
continue
@@ -405,6 +406,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
return (ExecutionResult.SUCCESS, None, None)

input_data_all = None
@@ -434,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = caches.outputs.get(source_node)[source_output]
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
for o in node_output:
resolved_output.append(o)

@@ -446,7 +448,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -549,11 +551,15 @@ async def await_completion():
subcache.clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
execution_list.cache_link(node_id, unique_id)
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)

caches.outputs.set(unique_id, output_data)
execution_list.cache_update(unique_id, output_data)

except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")

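In execution.py, get_input_data() now resolves linked inputs through execution_list.get_output_cache(from_node_id, to_node_id) instead of reading the global outputs cache, and execute() calls execution_list.cache_update() after producing (or re-sending) a node's output, so downstream nodes still see their inputs when caches.outputs is a NullCache. A hedged sketch of resolving one linked input under the new signature; the helper name, link value, and node IDs are hypothetical:

def resolve_linked_input(input_data, unique_id, execution_list):
    # input_data is a link of the form [from_node_id, output_index].
    from_node_id, output_index = input_data[0], input_data[1]
    if execution_list is None:
        return None  # might be a lazily evaluated input; the caller marks it missing
    cached_output = execution_list.get_output_cache(from_node_id, unique_id)
    if cached_output is None:
        return None  # not available yet; the caller marks it missing
    return cached_output[output_index]

# e.g. an input wired to socket 0 of node "4", consumed by node "7":
# value = resolve_linked_input(["4", 0], "7", execution_list)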
2 changes: 1 addition & 1 deletion main.py
@@ -173,7 +173,7 @@ def prompt_worker(q, server_instance):
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE
cache_type = execution.CacheType.NONE

e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
last_gc_collect = 0