71 changes: 64 additions & 7 deletions invokeai/app/services/shared/graph.py
@@ -816,10 +816,57 @@ class GraphExecutionState(BaseModel):
# Optional priority; others follow in name order
ready_order: list[str] = Field(default_factory=list)
indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes")
_iteration_path_cache: dict[str, tuple[int, ...]] = PrivateAttr(default_factory=dict)

def _type_key(self, node_obj: BaseInvocation) -> str:
return node_obj.__class__.__name__

def _get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]:
"""Best-effort outer->inner iteration indices for an execution node, stopping at collectors."""
cached = self._iteration_path_cache.get(exec_node_id)
if cached is not None:
return cached

# Only prepared execution nodes participate; otherwise treat as non-iterated.
source_node_id = self.prepared_source_mapping.get(exec_node_id)
if source_node_id is None:
self._iteration_path_cache[exec_node_id] = ()
return ()

# Source-graph iterator ancestry, with edges into collectors removed so iteration context doesn't leak.
it_g = self._iterator_graph(self.graph.nx_graph())
iterator_sources = [
n for n in nx.ancestors(it_g, source_node_id) if isinstance(self.graph.get_node(n), IterateInvocation)
]

# Order iterators outer->inner via topo order of the iterator graph.
topo = list(nx.topological_sort(it_g))
topo_index = {n: i for i, n in enumerate(topo)}
iterator_sources.sort(key=lambda n: topo_index.get(n, 0))

# Map iterator source nodes to the prepared iterator exec nodes that are ancestors of exec_node_id.
eg = self.execution_graph.nx_graph()
path: list[int] = []
for it_src in iterator_sources:
prepared = self.source_prepared_mapping.get(it_src)
if not prepared:
continue
it_exec = next((p for p in prepared if nx.has_path(eg, p, exec_node_id)), None)
if it_exec is None:
continue
it_node = self.execution_graph.nodes.get(it_exec)
if isinstance(it_node, IterateInvocation):
path.append(it_node.index)

# If this exec node is itself an iterator, include its own index as the innermost element.
node_obj = self.execution_graph.nodes.get(exec_node_id)
if isinstance(node_obj, IterateInvocation):
path.append(node_obj.index)

result = tuple(path)
self._iteration_path_cache[exec_node_id] = result
return result

def _queue_for(self, cls_name: str) -> Deque[str]:
q = self._ready_queues.get(cls_name)
if q is None:
@@ -843,7 +890,15 @@ def _enqueue_if_ready(self, nid: str) -> None:
if self.indegree[nid] != 0 or nid in self.executed:
return
node_obj = self.execution_graph.nodes[nid]
self._queue_for(self._type_key(node_obj)).append(nid)
q = self._queue_for(self._type_key(node_obj))
nid_path = self._get_iteration_path(nid)
# Insert in lexicographic outer->inner order; preserve FIFO for equal paths.
for i, existing in enumerate(q):
if self._get_iteration_path(existing) > nid_path:
q.insert(i, nid)
break
else:
q.append(nid)
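
The linear scan above keeps each per-type queue sorted by iteration path while preserving FIFO order for ties. A small, self-contained simulation of that insertion policy (the node names and paths below are hypothetical):

```python
from collections import deque

# Pretend these exec nodes became ready in this order, each with its iteration path.
ready_events = [("b", (1, 0)), ("a", (0, 0)), ("c", ()), ("d", (0, 0))]

q: deque[str] = deque()
paths = dict(ready_events)

for nid, nid_path in ready_events:
    # Same for/else shape as _enqueue_if_ready: insert before the first entry
    # with a strictly greater path, otherwise append (keeps FIFO for equal paths).
    for i, existing in enumerate(q):
        if paths[existing] > nid_path:
            q.insert(i, nid)
            break
    else:
        q.append(nid)

# "c" (non-iterated) runs first, "d" stays behind "a" because their paths tie.
assert list(q) == ["c", "a", "d", "b"]
```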

model_config = ConfigDict(
json_schema_extra={
@@ -1083,12 +1138,12 @@ def no_unexecuted_iter_ancestors(n: str) -> bool:

# Select the correct prepared parents for each iteration
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
# TODO: Handle a node mapping to none
eg = self.execution_graph.nx_graph_flat()
prepared_parent_mappings = [
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
for it in iterator_node_prepared_combinations
] # type: ignore
prepared_parent_mappings = [m for m in prepared_parent_mappings if all(p[1] is not None for p in m)]

# Create execution node for each iteration
for iteration_mappings in prepared_parent_mappings:
@@ -1110,15 +1165,17 @@ def _get_iteration_node(
if len(prepared_nodes) == 1:
return next(iter(prepared_nodes))

# Check if the requested node is an iterator
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
if prepared_iterator is not None:
return prepared_iterator

# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]

# If the requested node is an iterator, only accept it if it is compatible with all parent iterators
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
if prepared_iterator is not None:
if all(nx.has_path(execution_graph, pit[0], prepared_iterator) for pit in parent_iterators):
return prepared_iterator
return None

return next(
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
None,
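To see why the relocated iterator check needs the parent-iterator guard, here is a hypothetical two-level execution graph (names invented for illustration): a prepared inner iterator is only a valid match when every already-selected parent iterator can reach it.

```python
import networkx as nx

# Hypothetical execution graph: the outer iterator expanded to outer_0/outer_1,
# and each spawned one prepared inner iterator.
eg = nx.DiGraph()
eg.add_edges_from([("outer_0", "inner_0"), ("outer_1", "inner_1")])

parent_iterators = ["outer_0"]  # prepared parents already chosen for this iteration

def compatible(candidate: str) -> bool:
    # Mirrors the new guard: accept the prepared iterator only if it descends
    # from every parent prepared iterator in the execution graph.
    return all(nx.has_path(eg, p, candidate) for p in parent_iterators)

assert compatible("inner_0")        # belongs to the outer_0 iteration
assert not compatible("inner_1")    # belongs to outer_1, so it must be rejected
```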
45 changes: 45 additions & 0 deletions tests/test_graph_execution_state.py
@@ -225,3 +225,48 @@ def test_graph_iterate_execution_order(execution_number: int):
_ = invoke_next(g)
assert _[1].item == "Dinosaur Sushi"
_ = invoke_next(g)


# Because this tests deterministic ordering, we run it multiple times
@pytest.mark.parametrize("execution_number", range(5))
def test_graph_nested_iterate_execution_order(execution_number: int):
"""
Validates best-effort in-order execution for nodes expanded under nested iterators.
Expected lexicographic order by (outer_index, inner_index), subject to readiness.
"""
graph = Graph()

# Outer iterator: [0, 1]
graph.add_node(RangeInvocation(id="outer_range", start=0, stop=2, step=1))
graph.add_node(IterateInvocation(id="outer_iter"))

# Inner iterator is derived from the outer item:
# start = outer_item * 10
# stop = start + 2 => yields 2 items per outer item
graph.add_node(MultiplyInvocation(id="mul10", b=10))
graph.add_node(AddInvocation(id="stop_plus2", b=2))
graph.add_node(RangeInvocation(id="inner_range", start=0, stop=1, step=1))
graph.add_node(IterateInvocation(id="inner_iter"))

# Observe inner items (they encode outer via start=outer*10)
graph.add_node(AddInvocation(id="sum", b=0))

graph.add_edge(create_edge("outer_range", "collection", "outer_iter", "collection"))
graph.add_edge(create_edge("outer_iter", "item", "mul10", "a"))
graph.add_edge(create_edge("mul10", "value", "stop_plus2", "a"))
graph.add_edge(create_edge("mul10", "value", "inner_range", "start"))
graph.add_edge(create_edge("stop_plus2", "value", "inner_range", "stop"))
graph.add_edge(create_edge("inner_range", "collection", "inner_iter", "collection"))
graph.add_edge(create_edge("inner_iter", "item", "sum", "a"))

g = GraphExecutionState(graph=graph)
sum_values: list[int] = []

while True:
n, o = invoke_next(g)
if n is None:
break
if g.prepared_source_mapping[n.id] == "sum":
sum_values.append(o.value)

assert sum_values == [0, 1, 10, 11]
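
For reference, the expected sequence can be recomputed independently of the graph machinery; it is just the nested-range arithmetic the test wires up (outer item i feeds an inner range [10*i, 10*i + 2)).

```python
expected = [inner for outer in range(2) for inner in range(outer * 10, outer * 10 + 2)]
assert expected == [0, 1, 10, 11]
```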