diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 2501169edb5..eb050327dea 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -816,10 +816,57 @@ class GraphExecutionState(BaseModel): # Optional priority; others follow in name order ready_order: list[str] = Field(default_factory=list) indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes") + _iteration_path_cache: dict[str, tuple[int, ...]] = PrivateAttr(default_factory=dict) def _type_key(self, node_obj: BaseInvocation) -> str: return node_obj.__class__.__name__ + def _get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]: + """Best-effort outer->inner iteration indices for an execution node, stopping at collectors.""" + cached = self._iteration_path_cache.get(exec_node_id) + if cached is not None: + return cached + + # Only prepared execution nodes participate; otherwise treat as non-iterated. + source_node_id = self.prepared_source_mapping.get(exec_node_id) + if source_node_id is None: + self._iteration_path_cache[exec_node_id] = () + return () + + # Source-graph iterator ancestry, with edges into collectors removed so iteration context doesn't leak. + it_g = self._iterator_graph(self.graph.nx_graph()) + iterator_sources = [ + n for n in nx.ancestors(it_g, source_node_id) if isinstance(self.graph.get_node(n), IterateInvocation) + ] + + # Order iterators outer->inner via topo order of the iterator graph. + topo = list(nx.topological_sort(it_g)) + topo_index = {n: i for i, n in enumerate(topo)} + iterator_sources.sort(key=lambda n: topo_index.get(n, 0)) + + # Map iterator source nodes to the prepared iterator exec nodes that are ancestors of exec_node_id. + eg = self.execution_graph.nx_graph() + path: list[int] = [] + for it_src in iterator_sources: + prepared = self.source_prepared_mapping.get(it_src) + if not prepared: + continue + it_exec = next((p for p in prepared if nx.has_path(eg, p, exec_node_id)), None) + if it_exec is None: + continue + it_node = self.execution_graph.nodes.get(it_exec) + if isinstance(it_node, IterateInvocation): + path.append(it_node.index) + + # If this exec node is itself an iterator, include its own index as the innermost element. + node_obj = self.execution_graph.nodes.get(exec_node_id) + if isinstance(node_obj, IterateInvocation): + path.append(node_obj.index) + + result = tuple(path) + self._iteration_path_cache[exec_node_id] = result + return result + def _queue_for(self, cls_name: str) -> Deque[str]: q = self._ready_queues.get(cls_name) if q is None: @@ -843,7 +890,15 @@ def _enqueue_if_ready(self, nid: str) -> None: if self.indegree[nid] != 0 or nid in self.executed: return node_obj = self.execution_graph.nodes[nid] - self._queue_for(self._type_key(node_obj)).append(nid) + q = self._queue_for(self._type_key(node_obj)) + nid_path = self._get_iteration_path(nid) + # Insert in lexicographic outer->inner order; preserve FIFO for equal paths. + for i, existing in enumerate(q): + if self._get_iteration_path(existing) > nid_path: + q.insert(i, nid) + break + else: + q.append(nid) model_config = ConfigDict( json_schema_extra={ @@ -1083,12 +1138,12 @@ def no_unexecuted_iter_ancestors(n: str) -> bool: # Select the correct prepared parents for each iteration # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator - # TODO: Handle a node mapping to none eg = self.execution_graph.nx_graph_flat() prepared_parent_mappings = [ [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations ] # type: ignore + prepared_parent_mappings = [m for m in prepared_parent_mappings if all(p[1] is not None for p in m)] # Create execution node for each iteration for iteration_mappings in prepared_parent_mappings: @@ -1110,15 +1165,17 @@ def _get_iteration_node( if len(prepared_nodes) == 1: return next(iter(prepared_nodes)) - # Check if the requested node is an iterator - prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None) - if prepared_iterator is not None: - return prepared_iterator - # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)] + # If the requested node is an iterator, only accept it if it is compatible with all parent iterators + prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None) + if prepared_iterator is not None: + if all(nx.has_path(execution_graph, pit[0], prepared_iterator) for pit in parent_iterators): + return prepared_iterator + return None + return next( (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)), None, diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index 381c4c73482..b43698a2428 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -225,3 +225,48 @@ def test_graph_iterate_execution_order(execution_number: int): _ = invoke_next(g) assert _[1].item == "Dinosaur Sushi" _ = invoke_next(g) + + +# Because this tests deterministic ordering, we run it multiple times +@pytest.mark.parametrize("execution_number", range(5)) +def test_graph_nested_iterate_execution_order(execution_number: int): + """ + Validates best-effort in-order execution for nodes expanded under nested iterators. + Expected lexicographic order by (outer_index, inner_index), subject to readiness. + """ + graph = Graph() + + # Outer iterator: [0, 1] + graph.add_node(RangeInvocation(id="outer_range", start=0, stop=2, step=1)) + graph.add_node(IterateInvocation(id="outer_iter")) + + # Inner iterator is derived from the outer item: + # start = outer_item * 10 + # stop = start + 2 => yields 2 items per outer item + graph.add_node(MultiplyInvocation(id="mul10", b=10)) + graph.add_node(AddInvocation(id="stop_plus2", b=2)) + graph.add_node(RangeInvocation(id="inner_range", start=0, stop=1, step=1)) + graph.add_node(IterateInvocation(id="inner_iter")) + + # Observe inner items (they encode outer via start=outer*10) + graph.add_node(AddInvocation(id="sum", b=0)) + + graph.add_edge(create_edge("outer_range", "collection", "outer_iter", "collection")) + graph.add_edge(create_edge("outer_iter", "item", "mul10", "a")) + graph.add_edge(create_edge("mul10", "value", "stop_plus2", "a")) + graph.add_edge(create_edge("mul10", "value", "inner_range", "start")) + graph.add_edge(create_edge("stop_plus2", "value", "inner_range", "stop")) + graph.add_edge(create_edge("inner_range", "collection", "inner_iter", "collection")) + graph.add_edge(create_edge("inner_iter", "item", "sum", "a")) + + g = GraphExecutionState(graph=graph) + sum_values: list[int] = [] + + while True: + n, o = invoke_next(g) + if n is None: + break + if g.prepared_source_mapping[n.id] == "sum": + sum_values.append(o.value) + + assert sum_values == [0, 1, 10, 11]