Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,52 @@ def format_value(x):
else:
return str(x)


def get_route_node_keys(key, prompt):
'''Obtain route node number'''
keys=[]
for k,v in prompt.items():
for k1,v1 in v['inputs'].items():
if type(v1)==list and v1[0]==key:
keys.append(k)
keys=keys+get_route_node_keys(k,prompt)
return keys
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
return (True, None, None)

#CountingCycleTail----------start----------
startNum=None
delKeys=[]
backhaul={}
clTypes=['CountingCycleTail','IfExecute']
oldPrompt=None
if class_type in clTypes:
oldPrompt=copy.deepcopy(prompt)

if class_type=='CountingCycleTail':
startNum=prompt[unique_id]['inputs']['total'][0]
inputNum=prompt[unique_id]['inputs']['images'][0]
startInput=prompt[startNum]['inputs']
maxKey = int(sorted(list(prompt.keys()), key=lambda x: int(x))[-1])
delKeys=list(set(get_route_node_keys(startNum,prompt)))
delKeys.append(startNum)
delKeys = list(filter(lambda x: x != inputNum, delKeys))
for key in delKeys:
outputs.pop(key, None)
for i in range(startInput['i']+1,startInput['total'],startInput['stop']):
prompt[str(maxKey+i)]=prompt[inputNum]
prompt[unique_id]['inputs']['images'+str(i)]=[str(maxKey+i),prompt[unique_id]['inputs']['images'][-1]]
if i==startInput['i']+1:
backhaul['images'+str(i)]=prompt[unique_id]['inputs']['images']
else:
backhaul['images'+str(i)]=[str(maxKey+i-startInput['stop']),prompt[unique_id]['inputs']['images'][-1]]

#CountingCycleTail-----------end---------

for x in inputs:
input_data = inputs[x]
Expand All @@ -131,11 +170,28 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
#CountingCycleTail----------start----------
if class_type=='CountingCycleTail' and x !='images' and x.startswith('images'):
if startNum!=None:
prompt[startNum]['inputs']['i']=prompt[startNum]['inputs']['i']+prompt[startNum]['inputs']['stop']
prompt[startNum]['inputs']['images']=backhaul[x]
print(prompt[startNum])
for key in delKeys:
outputs.pop(key, None)
#CountingCycleTail-----------end---------
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
if result[0] is not True:
# Another node failed further upstream
return result

#If Execute-----------start---------
if class_type=='IfExecute' and x=='ANY':
if outputs[input_unique_id][output_index][0]:
inputs['IF_FALSE']=oldPrompt[unique_id]['inputs']['IF_TRUE']
else:
inputs['IF_TRUE']=oldPrompt[unique_id]['inputs']['IF_FALSE']
#If Execute-----------end---------

input_data_all = None
try:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
Expand All @@ -150,6 +206,13 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute

output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
#Restore prompt----------start----------
if class_type=='CountingCycleTail':
prompt[startNum]['inputs']=oldPrompt[startNum]['inputs']
prompt[unique_id]['inputs']=oldPrompt[unique_id]['inputs']
if class_type=='IfExecute':
prompt[unique_id]['inputs']=oldPrompt[unique_id]['inputs']
#Restore prompt-----------end---------
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
Expand Down
90 changes: 90 additions & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,7 +1739,90 @@ def expand_image(self, image, left, top, right, bottom, feathering):

return (new_image, mask)

class CountingCycleFirst:
def __init__(self):
pass

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"total": ("INT", {"default": 0, "min": 0, "max": 99999}),
"stop": ("INT", {"default": 1, "min": 1, "max": 999}),
"i": ("INT", {"default": 0, "min": 0, "max": 99999}),
}
}
RETURN_TYPES = ("INT","INT","INT","IMAGE",)
RETURN_NAMES = ("total","i","seed",'images',)
FUNCTION = "for_start_fun"

CATEGORY = "utils"

def for_start_fun(self,total,stop,i, **kwargs):
images=kwargs['images'] if 'images' in kwargs else None
random.seed(i)
return (total,i,random.randint(0,sys.maxsize),images,)

class CountingCycleTail:
def __init__(self):
pass

@classmethod
def INPUT_TYPES(cls):

return {
"required": {
"total": ("INT", {"forceInput": True}),
"images": ("IMAGE", ),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ('images',)
FUNCTION = "for_end_fun"

CATEGORY = "utils"

def for_end_fun(self,total,images, **kwargs):
for k,v in kwargs.items():
if k.startswith('images') and v!=None:
if images== None:
images = v
continue
if images.shape[1:] != v.shape[1:]:
v = comfy.utils.common_upscale(v.movedim(-1,1), images.shape[2], images.shape[1], "bilinear", "center").movedim(1,-1)
images = torch.cat((images, v), dim=0)

return (images,)

class AlwaysEqualProxy(str):
def __eq__(self, _):
return True

def __ne__(self, _):
return False

class IfExecute:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ANY": (AlwaysEqualProxy("*"),),
"IF_TRUE": (AlwaysEqualProxy("*"),),
"IF_FALSE": (AlwaysEqualProxy("*"),),
}
}

RETURN_TYPES = (AlwaysEqualProxy("*"),)

RETURN_NAMES = "?"

FUNCTION = "return_based_on_bool"

CATEGORY = "utils"

def return_based_on_bool(self, ANY, IF_TRUE, IF_FALSE):
return (IF_TRUE if ANY else IF_FALSE,)

NODE_CLASS_MAPPINGS = {
"KSampler": KSampler,
"CheckpointLoaderSimple": CheckpointLoaderSimple,
Expand Down Expand Up @@ -1807,6 +1890,10 @@ def expand_image(self, image, left, top, right, bottom, feathering):
"ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly,

"CountingCycleFirst": CountingCycleFirst,
"CountingCycleTail": CountingCycleTail,
"IfExecute": IfExecute,
}

NODE_DISPLAY_NAME_MAPPINGS = {
Expand Down Expand Up @@ -1866,6 +1953,9 @@ def expand_image(self, image, left, top, right, bottom, feathering):
# _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
"CountingCycleFirst":"Counting cycle first",
"CountingCycleTail": "Counting cycle tail",
"IfExecute": "If Execute",
}

EXTENSION_WEB_DIRS = {}
Expand Down