Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,12 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute

return (True, None, None)

def recursive_will_execute(prompt, outputs, current_item):
def recursive_will_execute(prompt, outputs, current_item, memo={}):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be backwards compatible with any external callers of this function.

unique_id = current_item

if unique_id in memo:
return memo[unique_id]

inputs = prompt[unique_id]['inputs']
will_execute = []
if unique_id in outputs:
Expand All @@ -207,9 +211,10 @@ def recursive_will_execute(prompt, outputs, current_item):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)

return will_execute + [unique_id]
memo[unique_id] = will_execute + [unique_id]
return memo[unique_id]

def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
Expand Down Expand Up @@ -377,7 +382,8 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):

while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
memo = {}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

memo should be created outside the lambda so it is reused for entire sorting algorithm

to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]

# This call shouldn't raise anything if there's an error deep in
Expand Down