diff --git a/execution.py b/execution.py index 5f8ea835c78f..0a98ed43e52d 100644 --- a/execution.py +++ b/execution.py @@ -116,6 +116,16 @@ def format_value(x): else: return str(x) + +def get_route_node_keys(key, prompt): + '''Obtain route node number''' + keys=[] + for k,v in prompt.items(): + for k1,v1 in v['inputs'].items(): + if type(v1)==list and v1[0]==key: + keys.append(k) + keys=keys+get_route_node_keys(k,prompt) + return keys def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -123,6 +133,35 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: return (True, None, None) + + #CountingCycleTail----------start---------- + startNum=None + delKeys=[] + backhaul={} + clTypes=['CountingCycleTail','IfExecute'] + oldPrompt=None + if class_type in clTypes: + oldPrompt=copy.deepcopy(prompt) + + if class_type=='CountingCycleTail': + startNum=prompt[unique_id]['inputs']['total'][0] + inputNum=prompt[unique_id]['inputs']['images'][0] + startInput=prompt[startNum]['inputs'] + maxKey = int(sorted(list(prompt.keys()), key=lambda x: int(x))[-1]) + delKeys=list(set(get_route_node_keys(startNum,prompt))) + delKeys.append(startNum) + delKeys = list(filter(lambda x: x != inputNum, delKeys)) + for key in delKeys: + outputs.pop(key, None) + for i in range(startInput['i']+1,startInput['total'],startInput['stop']): + prompt[str(maxKey+i)]=prompt[inputNum] + prompt[unique_id]['inputs']['images'+str(i)]=[str(maxKey+i),prompt[unique_id]['inputs']['images'][-1]] + if i==startInput['i']+1: + backhaul['images'+str(i)]=prompt[unique_id]['inputs']['images'] + else: + backhaul['images'+str(i)]=[str(maxKey+i-startInput['stop']),prompt[unique_id]['inputs']['images'][-1]] + + #CountingCycleTail-----------end--------- for x in inputs: input_data = inputs[x] @@ -131,11 +170,28 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: + #CountingCycleTail----------start---------- + if class_type=='CountingCycleTail' and x !='images' and x.startswith('images'): + if startNum!=None: + prompt[startNum]['inputs']['i']=prompt[startNum]['inputs']['i']+prompt[startNum]['inputs']['stop'] + prompt[startNum]['inputs']['images']=backhaul[x] + print(prompt[startNum]) + for key in delKeys: + outputs.pop(key, None) + #CountingCycleTail-----------end--------- result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage) if result[0] is not True: # Another node failed further upstream return result + #If Execute-----------start--------- + if class_type=='IfExecute' and x=='ANY': + if outputs[input_unique_id][output_index][0]: + inputs['IF_FALSE']=oldPrompt[unique_id]['inputs']['IF_TRUE'] + else: + inputs['IF_TRUE']=oldPrompt[unique_id]['inputs']['IF_FALSE'] + #If Execute-----------end--------- + input_data_all = None try: input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) @@ -150,6 +206,13 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data + #Restore prompt----------start---------- + if class_type=='CountingCycleTail': + prompt[startNum]['inputs']=oldPrompt[startNum]['inputs'] + prompt[unique_id]['inputs']=oldPrompt[unique_id]['inputs'] + if class_type=='IfExecute': + prompt[unique_id]['inputs']=oldPrompt[unique_id]['inputs'] + #Restore prompt-----------end--------- if len(output_ui) > 0: outputs_ui[unique_id] = output_ui if server.client_id is not None: diff --git a/nodes.py b/nodes.py index 453f6e60656f..59406f2b6a1c 100644 --- a/nodes.py +++ b/nodes.py @@ -1739,7 +1739,90 @@ def expand_image(self, image, left, top, right, bottom, feathering): return (new_image, mask) +class CountingCycleFirst: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "total": ("INT", {"default": 0, "min": 0, "max": 99999}), + "stop": ("INT", {"default": 1, "min": 1, "max": 999}), + "i": ("INT", {"default": 0, "min": 0, "max": 99999}), + } + } + RETURN_TYPES = ("INT","INT","INT","IMAGE",) + RETURN_NAMES = ("total","i","seed",'images',) + FUNCTION = "for_start_fun" + + CATEGORY = "utils" + + def for_start_fun(self,total,stop,i, **kwargs): + images=kwargs['images'] if 'images' in kwargs else None + random.seed(i) + return (total,i,random.randint(0,sys.maxsize),images,) + +class CountingCycleTail: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + + return { + "required": { + "total": ("INT", {"forceInput": True}), + "images": ("IMAGE", ), + } + } + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ('images',) + FUNCTION = "for_end_fun" + + CATEGORY = "utils" + + def for_end_fun(self,total,images, **kwargs): + for k,v in kwargs.items(): + if k.startswith('images') and v!=None: + if images== None: + images = v + continue + if images.shape[1:] != v.shape[1:]: + v = comfy.utils.common_upscale(v.movedim(-1,1), images.shape[2], images.shape[1], "bilinear", "center").movedim(1,-1) + images = torch.cat((images, v), dim=0) + + return (images,) + +class AlwaysEqualProxy(str): + def __eq__(self, _): + return True + + def __ne__(self, _): + return False +class IfExecute: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ANY": (AlwaysEqualProxy("*"),), + "IF_TRUE": (AlwaysEqualProxy("*"),), + "IF_FALSE": (AlwaysEqualProxy("*"),), + } + } + + RETURN_TYPES = (AlwaysEqualProxy("*"),) + + RETURN_NAMES = "?" + + FUNCTION = "return_based_on_bool" + + CATEGORY = "utils" + + def return_based_on_bool(self, ANY, IF_TRUE, IF_FALSE): + return (IF_TRUE if ANY else IF_FALSE,) + NODE_CLASS_MAPPINGS = { "KSampler": KSampler, "CheckpointLoaderSimple": CheckpointLoaderSimple, @@ -1807,6 +1890,10 @@ def expand_image(self, image, left, top, right, bottom, feathering): "ConditioningZeroOut": ConditioningZeroOut, "ConditioningSetTimestepRange": ConditioningSetTimestepRange, "LoraLoaderModelOnly": LoraLoaderModelOnly, + + "CountingCycleFirst": CountingCycleFirst, + "CountingCycleTail": CountingCycleTail, + "IfExecute": IfExecute, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1866,6 +1953,9 @@ def expand_image(self, image, left, top, right, bottom, feathering): # _for_testing "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", + "CountingCycleFirst":"Counting cycle first", + "CountingCycleTail": "Counting cycle tail", + "IfExecute": "If Execute", } EXTENSION_WEB_DIRS = {}