diff --git a/comfy/controlnet.py b/comfy/controlnet.py index a44f3725e80..e6a0d1e5976 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -36,6 +36,10 @@ import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet import comfy.cldm.dit_embedder +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.hooks import HookGroup + def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] @@ -78,6 +82,7 @@ def __init__(self): self.concat_mask = False self.extra_concat_orig = [] self.extra_concat = None + self.extra_hooks: HookGroup = None self.preprocess_image = lambda a: a def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): @@ -115,6 +120,14 @@ def get_models(self): if self.previous_controlnet is not None: out += self.previous_controlnet.get_models() return out + + def get_extra_hooks(self): + out = [] + if self.extra_hooks is not None: + out.append(self.extra_hooks) + if self.previous_controlnet is not None: + out += self.previous_controlnet.get_extra_hooks() + return out def copy_to(self, c): c.cond_hint_original = self.cond_hint_original @@ -130,6 +143,7 @@ def copy_to(self, c): c.strength_type = self.strength_type c.concat_mask = self.concat_mask c.extra_concat_orig = self.extra_concat_orig.copy() + c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None c.preprocess_image = self.preprocess_image def inference_memory_requirements(self, dtype): @@ -200,10 +214,10 @@ def __init__(self, control_model=None, global_average_pooling=False, compression self.concat_mask = concat_mask self.preprocess_image = preprocess_image - def get_control(self, x_noisy, t, cond, batched_number): + def get_control(self, x_noisy, t, cond, batched_number, transformer_options): control_prev = None if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options) if self.timestep_range is not None: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: @@ -758,10 +772,10 @@ def scale_image_to(self, width, height): height = math.ceil(height / unshuffle_amount) * unshuffle_amount return width, height - def get_control(self, x_noisy, t, cond, batched_number): + def get_control(self, x_noisy, t, cond, batched_number, transformer_options): control_prev = None if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options) if self.timestep_range is not None: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: diff --git a/comfy/hooks.py b/comfy/hooks.py new file mode 100644 index 00000000000..ccb8183b91d --- /dev/null +++ b/comfy/hooks.py @@ -0,0 +1,690 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Callable +import enum +import math +import torch +import numpy as np +import itertools + +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher, PatcherInjection + from comfy.model_base import BaseModel + from comfy.sd import CLIP +import comfy.lora +import comfy.model_management +import comfy.patcher_extension +from node_helpers import conditioning_set_values + +class EnumHookMode(enum.Enum): + MinVram = "minvram" + MaxSpeed = "maxspeed" + +class EnumHookType(enum.Enum): + Weight = "weight" + Patch = "patch" + ObjectPatch = "object_patch" + AddModels = "add_models" + Callbacks = "callbacks" + Wrappers = "wrappers" + SetInjections = "add_injections" + +class EnumWeightTarget(enum.Enum): + Model = "model" + Clip = "clip" + +class _HookRef: + pass + +# NOTE: this is an example of how the should_register function should look +def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + return True + + +class Hook: + def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, + hook_keyframe: 'HookKeyframeGroup'=None): + self.hook_type = hook_type + self.hook_ref = hook_ref if hook_ref else _HookRef() + self.hook_id = hook_id + self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + self.custom_should_register = default_should_register + self.auto_apply_to_nonpositive = False + + @property + def strength(self): + return self.hook_keyframe.strength + + def initialize_timesteps(self, model: 'BaseModel'): + self.reset() + self.hook_keyframe.initialize_timesteps(model) + + def reset(self): + self.hook_keyframe.reset() + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: Hook = subtype() + c.hook_type = self.hook_type + c.hook_ref = self.hook_ref + c.hook_id = self.hook_id + c.hook_keyframe = self.hook_keyframe + c.custom_should_register = self.custom_should_register + # TODO: make this do something + c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive + return c + + def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + return self.custom_should_register(self, model, model_options, target, registered) + + def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") + + def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]): + pass + + def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]): + pass + + def __eq__(self, other: 'Hook'): + return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref + + def __hash__(self): + return hash(self.hook_ref) + +class WeightHook(Hook): + def __init__(self, strength_model=1.0, strength_clip=1.0): + super().__init__(hook_type=EnumHookType.Weight) + self.weights: dict = None + self.weights_clip: dict = None + self.need_weight_init = True + self._strength_model = strength_model + self._strength_clip = strength_clip + + @property + def strength_model(self): + return self._strength_model * self.strength + + @property + def strength_clip(self): + return self._strength_clip * self.strength + + def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + if not self.should_register(model, model_options, target, registered): + return False + weights = None + if target == EnumWeightTarget.Model: + strength = self._strength_model + else: + strength = self._strength_clip + + if self.need_weight_init: + key_map = {} + if target == EnumWeightTarget.Model: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + else: + key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) + weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) + else: + if target == EnumWeightTarget.Model: + weights = self.weights + else: + weights = self.weights_clip + k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) + registered.append(self) + return True + # TODO: add logs about any keys that were not applied + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: WeightHook = super().clone(subtype) + c.weights = self.weights + c.weights_clip = self.weights_clip + c.need_weight_init = self.need_weight_init + c._strength_model = self._strength_model + c._strength_clip = self._strength_clip + return c + +class PatchHook(Hook): + def __init__(self): + super().__init__(hook_type=EnumHookType.Patch) + self.patches: dict = None + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: PatchHook = super().clone(subtype) + c.patches = self.patches + return c + # TODO: add functionality + +class ObjectPatchHook(Hook): + def __init__(self): + super().__init__(hook_type=EnumHookType.ObjectPatch) + self.object_patches: dict = None + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: ObjectPatchHook = super().clone(subtype) + c.object_patches = self.object_patches + return c + # TODO: add functionality + +class AddModelsHook(Hook): + def __init__(self, key: str=None, models: list['ModelPatcher']=None): + super().__init__(hook_type=EnumHookType.AddModels) + self.key = key + self.models = models + self.append_when_same = True + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: AddModelsHook = super().clone(subtype) + c.key = self.key + c.models = self.models.copy() if self.models else self.models + c.append_when_same = self.append_when_same + return c + # TODO: add functionality + +class CallbackHook(Hook): + def __init__(self, key: str=None, callback: Callable=None): + super().__init__(hook_type=EnumHookType.Callbacks) + self.key = key + self.callback = callback + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: CallbackHook = super().clone(subtype) + c.key = self.key + c.callback = self.callback + return c + # TODO: add functionality + +class WrapperHook(Hook): + def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): + super().__init__(hook_type=EnumHookType.Wrappers) + self.wrappers_dict = wrappers_dict + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: WrapperHook = super().clone(subtype) + c.wrappers_dict = self.wrappers_dict + return c + + def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + if not self.should_register(model, model_options, target, registered): + return False + add_model_options = {"transformer_options": self.wrappers_dict} + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + registered.append(self) + return True + +class SetInjectionsHook(Hook): + def __init__(self, key: str=None, injections: list['PatcherInjection']=None): + super().__init__(hook_type=EnumHookType.SetInjections) + self.key = key + self.injections = injections + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: SetInjectionsHook = super().clone(subtype) + c.key = self.key + c.injections = self.injections.copy() if self.injections else self.injections + return c + + def add_hook_injections(self, model: 'ModelPatcher'): + # TODO: add functionality + pass + +class HookGroup: + def __init__(self): + self.hooks: list[Hook] = [] + + def add(self, hook: Hook): + if hook not in self.hooks: + self.hooks.append(hook) + + def contains(self, hook: Hook): + return hook in self.hooks + + def clone(self): + c = HookGroup() + for hook in self.hooks: + c.add(hook.clone()) + return c + + def clone_and_combine(self, other: 'HookGroup'): + c = self.clone() + if other is not None: + for hook in other.hooks: + c.add(hook.clone()) + return c + + def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): + if hook_kf is None: + hook_kf = HookKeyframeGroup() + else: + hook_kf = hook_kf.clone() + for hook in self.hooks: + hook.hook_keyframe = hook_kf + + def get_dict_repr(self): + d: dict[EnumHookType, dict[Hook, None]] = {} + for hook in self.hooks: + with_type = d.setdefault(hook.hook_type, {}) + with_type[hook] = None + return d + + def get_hooks_for_clip_schedule(self): + scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} + for hook in self.hooks: + # only care about WeightHooks, for now + if hook.hook_type == EnumHookType.Weight: + hook_schedule = [] + # if no hook keyframes, assign default value + if len(hook.hook_keyframe.keyframes) == 0: + hook_schedule.append(((0.0, 1.0), None)) + scheduled_hooks[hook] = hook_schedule + continue + # find ranges of values + prev_keyframe = hook.hook_keyframe.keyframes[0] + for keyframe in hook.hook_keyframe.keyframes: + if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): + hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) + prev_keyframe = keyframe + elif keyframe.start_percent == prev_keyframe.start_percent: + prev_keyframe = keyframe + # create final range, assuming last start_percent was not 1.0 + if not math.isclose(prev_keyframe.start_percent, 1.0): + hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + scheduled_hooks[hook] = hook_schedule + # hooks should not have their schedules in a list of tuples + all_ranges: list[tuple[float, float]] = [] + for range_kfs in scheduled_hooks.values(): + for t_range, keyframe in range_kfs: + all_ranges.append(t_range) + # turn list of ranges into boundaries + boundaries_set = set(itertools.chain.from_iterable(all_ranges)) + boundaries_set.add(0.0) + boundaries = sorted(boundaries_set) + real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)] + # with real ranges defined, give appropriate hooks w/ keyframes for each range + scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = [] + for t_range in real_ranges: + hooks_schedule = [] + for hook, val in scheduled_hooks.items(): + keyframe = None + # check if is a keyframe that works for the current t_range + for stored_range, stored_kf in val: + # if stored start is less than current end, then fits - give it assigned keyframe + if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]: + keyframe = stored_kf + break + hooks_schedule.append((hook, keyframe)) + scheduled_keyframes.append((t_range, hooks_schedule)) + return scheduled_keyframes + + def reset(self): + for hook in self.hooks: + hook.reset() + + @staticmethod + def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup': + actual: list[HookGroup] = [] + for group in hooks_list: + if group is not None: + actual.append(group) + if len(actual) < require_count: + raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.") + # if no hooks, then return None + if len(actual) == 0: + return None + # if only 1 hook, just return itself without cloning + elif len(actual) == 1: + return actual[0] + final_hook: HookGroup = None + for hook in actual: + if final_hook is None: + final_hook = hook.clone() + else: + final_hook = final_hook.clone_and_combine(hook) + return final_hook + + +class HookKeyframe: + def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1): + self.strength = strength + # scheduling + self.start_percent = float(start_percent) + self.start_t = 999999999.9 + self.guarantee_steps = guarantee_steps + + def clone(self): + c = HookKeyframe(strength=self.strength, + start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) + c.start_t = self.start_t + return c + +class HookKeyframeGroup: + def __init__(self): + self.keyframes: list[HookKeyframe] = [] + self._current_keyframe: HookKeyframe = None + self._current_used_steps = 0 + self._current_index = 0 + self._current_strength = None + self._curr_t = -1. + + # properties shadow those of HookWeightsKeyframe + @property + def strength(self): + if self._current_keyframe is not None: + return self._current_keyframe.strength + return 1.0 + + def reset(self): + self._current_keyframe = None + self._current_used_steps = 0 + self._current_index = 0 + self._current_strength = None + self.curr_t = -1. + self._set_first_as_current() + + def add(self, keyframe: HookKeyframe): + # add to end of list, then sort + self.keyframes.append(keyframe) + self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent") + self._set_first_as_current() + + def _set_first_as_current(self): + if len(self.keyframes) > 0: + self._current_keyframe = self.keyframes[0] + else: + self._current_keyframe = None + + def has_index(self, index: int): + return index >= 0 and index < len(self.keyframes) + + def is_empty(self): + return len(self.keyframes) == 0 + + def clone(self): + c = HookKeyframeGroup() + for keyframe in self.keyframes: + c.keyframes.append(keyframe.clone()) + c._set_first_as_current() + return c + + def initialize_timesteps(self, model: 'BaseModel'): + for keyframe in self.keyframes: + keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) + + def prepare_current_keyframe(self, curr_t: float) -> bool: + if self.is_empty(): + return False + if curr_t == self._curr_t: + return False + prev_index = self._current_index + prev_strength = self._current_strength + # if met guaranteed steps, look for next keyframe in case need to switch + if self._current_used_steps >= self._current_keyframe.guarantee_steps: + # if has next index, loop through and see if need to switch + if self.has_index(self._current_index+1): + for i in range(self._current_index+1, len(self.keyframes)): + eval_c = self.keyframes[i] + # check if start_t is greater or equal to curr_t + # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling + if eval_c.start_t >= curr_t: + self._current_index = i + self._current_strength = eval_c.strength + self._current_keyframe = eval_c + self._current_used_steps = 0 + # if guarantee_steps greater than zero, stop searching for other keyframes + if self._current_keyframe.guarantee_steps > 0: + break + # if eval_c is outside the percent range, stop looking further + else: break + # update steps current context is used + self._current_used_steps += 1 + # update current timestep this was performed on + self._curr_t = curr_t + # return True if keyframe changed, False if no change + return prev_index != self._current_index and prev_strength != self._current_strength + + +class InterpolationMethod: + LINEAR = "linear" + EASE_IN = "ease_in" + EASE_OUT = "ease_out" + EASE_IN_OUT = "ease_in_out" + + _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT] + + @classmethod + def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False): + diff = num_to - num_from + if method == cls.LINEAR: + weights = torch.linspace(num_from, num_to, length) + elif method == cls.EASE_IN: + index = torch.linspace(0, 1, length) + weights = diff * np.power(index, 2) + num_from + elif method == cls.EASE_OUT: + index = torch.linspace(0, 1, length) + weights = diff * (1 - np.power(1 - index, 2)) + num_from + elif method == cls.EASE_IN_OUT: + index = torch.linspace(0, 1, length) + weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from + else: + raise ValueError(f"Unrecognized interpolation method '{method}'.") + if reverse: + weights = weights.flip(dims=(0,)) + return weights + +def get_sorted_list_via_attr(objects: list, attr: str) -> list: + if not objects: + return objects + elif len(objects) <= 1: + return [x for x in objects] + # now that we know we have to sort, do it following these rules: + # a) if objects have same value of attribute, maintain their relative order + # b) perform sorting of the groups of objects with same attributes + unique_attrs = {} + for o in objects: + val_attr = getattr(o, attr) + attr_list: list = unique_attrs.get(val_attr, list()) + attr_list.append(o) + if val_attr not in unique_attrs: + unique_attrs[val_attr] = attr_list + # now that we have the unique attr values grouped together in relative order, sort them by key + sorted_attrs = dict(sorted(unique_attrs.items())) + # now flatten out the dict into a list to return + sorted_list = [] + for object_list in sorted_attrs.values(): + sorted_list.extend(object_list) + return sorted_list + +def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): + hook_group = HookGroup() + hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) + hook_group.add(hook) + hook.weights = lora + return hook_group + +def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float): + hook_group = HookGroup() + hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) + hook_group.add(hook) + patches_model = None + patches_clip = None + if weights_model is not None: + patches_model = {} + for key in weights_model: + patches_model[key] = ("model_as_lora", (weights_model[key],)) + if weights_clip is not None: + patches_clip = {} + for key in weights_clip: + patches_clip[key] = ("model_as_lora", (weights_clip[key],)) + hook.weights = patches_model + hook.weights_clip = patches_clip + hook.need_weight_init = False + return hook_group + +def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True): + if model is None: + return None + patches_model: dict[str, torch.Tensor] = model.model.state_dict() + if discard_model_sampling: + # do not include ANY model_sampling components of the model that should act as a patch + for key in list(patches_model.keys()): + if key.startswith("model_sampling"): + patches_model.pop(key, None) + return patches_model + +# NOTE: this function shows how to register weight hooks directly on the ModelPatchers +def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor], + strength_model: float, strength_clip: float): + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + + hook_group = HookGroup() + hook = WeightHook() + hook_group.add(hook) + loaded: dict[str] = comfy.lora.load_lora(lora, key_map) + if model is not None: + new_modelpatcher = model.clone() + k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model) + else: + k = () + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip) + else: + k1 = () + new_clip = None + k = set(k) + k1 = set(k1) + for x in loaded: + if (x not in k) and (x not in k1): + print(f"NOT LOADED {x}") + return (new_modelpatcher, new_clip, hook_group) + +def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]): + hooks_key = 'hooks' + # if hooks only exist in one dict, do what's needed so that it ends up in c_dict + if hooks_key not in values: + return + if hooks_key not in c_dict: + hooks_value = values.get(hooks_key, None) + if hooks_value is not None: + c_dict[hooks_key] = hooks_value + return + # otherwise, need to combine with minimum duplication via cache + hooks_tuple = (c_dict[hooks_key], values[hooks_key]) + cached_hooks = cache.get(hooks_tuple, None) + if cached_hooks is None: + new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1]) + cache[hooks_tuple] = new_hooks + c_dict[hooks_key] = new_hooks + else: + c_dict[hooks_key] = cache[hooks_tuple] + +def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): + c = [] + hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} + for t in conditioning: + n = [t[0], t[1].copy()] + for k in values: + if append_hooks and k == 'hooks': + _combine_hooks_from_values(n[1], values, hooks_combine_cache) + else: + n[1][k] = values[k] + c.append(n) + + return c + +def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): + if hooks is None: + return cond + return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) + +def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): + if timestep_range is None: + return cond + return conditioning_set_values(cond, {"start_percent": timestep_range[0], + "end_percent": timestep_range[1]}) + +def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float): + if mask is None: + return cond + set_area_to_bounds = False + if set_cond_area != 'default': + set_area_to_bounds = True + if len(mask.shape) < 3: + mask = mask.unsqueeze(0) + return conditioning_set_values(cond, {'mask': mask, + 'set_area_to_bounds': set_area_to_bounds, + 'mask_strength': strength}) + +def combine_conditioning(conds: list): + combined_conds = [] + for cond in conds: + combined_conds.extend(cond) + return combined_conds + +def combine_with_new_conds(conds: list, new_conds: list): + combined_conds = [] + for c, new_c in zip(conds, new_conds): + combined_conds.append(combine_conditioning([c, new_c])) + return combined_conds + +def set_conds_props(conds: list, strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): + final_conds = [] + for c in conds: + # first, apply lora_hook to conditioning, if provided + c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) + # next, apply mask to conditioning + c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) + # apply timesteps, if present + c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range) + # finally, apply mask to conditioning and store + final_conds.append(c) + return final_conds + +def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", + mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): + combined_conds = [] + for c, masked_c in zip(conds, new_conds): + # first, apply lora_hook to new conditioning, if provided + masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) + # next, apply mask to new conditioning, if provided + masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength) + # apply timesteps, if present + masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range) + # finally, combine with existing conditioning and store + combined_conds.append(combine_conditioning([c, masked_c])) + return combined_conds + +def set_default_conds_and_combine(conds: list, new_conds: list, + hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): + combined_conds = [] + for c, new_c in zip(conds, new_conds): + # first, apply lora_hook to new conditioning, if provided + new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks) + # next, add default_cond key to cond so that during sampling, it can be identified + new_c = conditioning_set_values(new_c, {'default': True}) + # apply timesteps, if present + new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range) + # finally, combine with existing conditioning and store + combined_conds.append(combine_conditioning([c, new_c])) + return combined_conds diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 2902073d5ea..3f7fee708ff 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -15,6 +15,7 @@ ) from ..attention import SpatialTransformer, SpatialVideoTransformer, default from comfy.ldm.util import exists +import comfy.patcher_extension import comfy.ops ops = comfy.ops.disable_weight_init @@ -47,6 +48,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: + if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: + found_patched = False + for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: + if isinstance(layer, class_type): + x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) + found_patched = True + break + if found_patched: + continue x = layer(x) return x @@ -819,6 +829,13 @@ def get_resblock( ) def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timesteps, context, y, control, transformer_options, **kwargs) + + def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. diff --git a/comfy/lora.py b/comfy/lora.py index 1080169b191..b6d9a8d04f8 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -33,7 +33,7 @@ } -def load_lora(lora, to_load): +def load_lora(lora, to_load, log_missing=True): patch_dict = {} loaded_keys = set() for x in to_load: @@ -213,9 +213,10 @@ def load_lora(lora, to_load): patch_dict[to_load[x]] = ("set", (set_weight,)) loaded_keys.add(set_weight_name) - for x in lora.keys(): - if x not in loaded_keys: - logging.warning("lora key not loaded: {}".format(x)) + if log_missing: + for x in lora.keys(): + if x not in loaded_keys: + logging.warning("lora key not loaded: {}".format(x)) return patch_dict @@ -429,7 +430,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten return padded_tensor -def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): +def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None): for p in patches: strength = p[0] v = p[1] @@ -471,6 +472,11 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype)) elif patch_type == "set": weight.copy_(v[0]) + elif patch_type == "model_as_lora": + target_weight: torch.Tensor = v[0] + diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \ + comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype) + weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype)) elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) diff --git a/comfy/model_base.py b/comfy/model_base.py index c305014a4a3..8f37af660d6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -33,12 +33,16 @@ import comfy.ldm.lightricks.model import comfy.model_management +import comfy.patcher_extension import comfy.conds import comfy.ops from enum import Enum from . import utils import comfy.latent_formats import math +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher class ModelType(Enum): EPS = 1 @@ -95,6 +99,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod self.model_config = model_config self.manual_cast_dtype = model_config.manual_cast_dtype self.device = device + self.current_patcher: 'ModelPatcher' = None if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: @@ -120,6 +125,13 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod self.memory_usage_factor = model_config.memory_usage_factor def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._apply_model, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options) + ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b64c795aa50..4ae3ad25db2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -16,6 +16,8 @@ along with this program. If not, see . """ +from __future__ import annotations +from typing import Optional, Callable import torch import copy import inspect @@ -28,6 +30,9 @@ import comfy.float import comfy.model_management import comfy.lora +import comfy.hooks +import comfy.patcher_extension +from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection from comfy.comfy_types import UnetWrapperFunction def string_to_seed(data): @@ -76,6 +81,17 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ model_options["disable_cfg1_optimization"] = True return model_options +def create_model_options_clone(orig_model_options: dict): + return comfy.patcher_extension.copy_nested_dicts(orig_model_options) + +def create_hook_patches_clone(orig_hook_patches): + new_hook_patches = {} + for hook_ref in orig_hook_patches: + new_hook_patches[hook_ref] = {} + for k in orig_hook_patches[hook_ref]: + new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] + return new_hook_patches + def wipe_lowvram_weight(m): if hasattr(m, "prev_comfy_cast_weights"): m.comfy_cast_weights = m.prev_comfy_cast_weights @@ -119,6 +135,49 @@ def get_key_weight(model, key): return weight, set_func, convert_func +class AutoPatcherEjector: + def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): + self.model = model + self.was_injected = False + self.prev_skip_injection = False + self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only + + def __enter__(self): + self.was_injected = False + self.prev_skip_injection = self.model.skip_injection + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = True + if self.model.is_injected: + self.model.eject_model() + self.was_injected = True + + def __exit__(self, *args): + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = self.prev_skip_injection + self.model.inject_model() + if self.was_injected and not self.model.skip_injection: + self.model.inject_model() + self.model.skip_injection = self.prev_skip_injection + +class MemoryCounter: + def __init__(self, initial: int, minimum=0): + self.value = initial + self.minimum = minimum + # TODO: add a safe limit besides 0 + + def use(self, weight: torch.Tensor): + weight_size = weight.nelement() * weight.element_size() + if self.is_useable(weight_size): + self.decrement(weight_size) + return True + return False + + def is_useable(self, used: int): + return self.value - used > self.minimum + + def decrement(self, used: int): + self.value -= used + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -141,6 +200,24 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.patches_uuid = uuid.uuid4() self.parent = None + self.attachments: dict[str] = {} + self.additional_models: dict[str, list[ModelPatcher]] = {} + self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks() + self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers() + + self.is_injected = False + self.skip_injection = False + self.injections: dict[str, list[PatcherInjection]] = {} + + self.hook_patches: dict[comfy.hooks._HookRef] = {} + self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} + self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} + self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} + self.current_hooks: Optional[comfy.hooks.HookGroup] = None + self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time + self.is_clip = False + self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -177,6 +254,47 @@ def clone(self): n.backup = self.backup n.object_patches_backup = self.object_patches_backup n.parent = self + + # attachments + n.attachments = {} + for k in self.attachments: + if hasattr(self.attachments[k], "on_model_patcher_clone"): + n.attachments[k] = self.attachments[k].on_model_patcher_clone() + else: + n.attachments[k] = self.attachments[k] + # additional models + for k, c in self.additional_models.items(): + n.additional_models[k] = [x.clone() for x in c] + # callbacks + for k, c in self.callbacks.items(): + n.callbacks[k] = {} + for k1, c1 in c.items(): + n.callbacks[k][k1] = c1.copy() + # sample wrappers + for k, w in self.wrappers.items(): + n.wrappers[k] = {} + for k1, w1 in w.items(): + n.wrappers[k][k1] = w1.copy() + # injection + n.is_injected = self.is_injected + n.skip_injection = self.skip_injection + for k, i in self.injections.items(): + n.injections[k] = i.copy() + # hooks + n.hook_patches = create_hook_patches_clone(self.hook_patches) + n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) + for group in self.cached_hook_patches: + n.cached_hook_patches[group] = {} + for k in self.cached_hook_patches[group]: + n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k] + n.hook_backup = self.hook_backup + n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks + n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks + n.is_clip = self.is_clip + n.hook_mode = self.hook_mode + + for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): + callback(self, n) return n def is_clone(self, other): @@ -184,10 +302,29 @@ def is_clone(self, other): return True return False - def clone_has_same_weights(self, clone): + def clone_has_same_weights(self, clone: 'ModelPatcher'): if not self.is_clone(clone): return False + if self.current_hooks != clone.current_hooks: + return False + if self.forced_hooks != clone.forced_hooks: + return False + if self.hook_patches.keys() != clone.hook_patches.keys(): + return False + if self.attachments.keys() != clone.attachments.keys(): + return False + if self.additional_models.keys() != clone.additional_models.keys(): + return False + for key in self.callbacks: + if len(self.callbacks[key]) != len(clone.callbacks[key]): + return False + for key in self.wrappers: + if len(self.wrappers[key]) != len(clone.wrappers[key]): + return False + if self.injections.keys() != clone.injections.keys(): + return False + if len(self.patches) == 0 and len(clone.patches) == 0: return True @@ -256,6 +393,12 @@ def set_model_input_block_patch_after_skip(self, patch): def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") + def set_model_emb_patch(self, patch): + self.set_model_patch(patch, "emb_patch") + + def set_model_forward_timestep_embed_patch(self, patch): + self.set_model_patch(patch, "forward_timestep_embed_patch") + def add_object_patch(self, name, obj): self.object_patches[name] = obj @@ -294,27 +437,28 @@ def model_dtype(self): return self.model.get_dtype() def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] + with self.use_ejected(): + p = set() + model_sd = self.model.state_dict() + for k in patches: + offset = None + function = None + if isinstance(k, str): + key = k + else: + offset = k[1] + key = k[0] + if len(k) > 2: + function = k[2] - if key in model_sd: - p.add(k) - current_patches = self.patches.get(key, []) - current_patches.append((strength_patch, patches[k], strength_model, offset, function)) - self.patches[key] = current_patches + if key in model_sd: + p.add(k) + current_patches = self.patches.get(key, []) + current_patches.append((strength_patch, patches[k], strength_model, offset, function)) + self.patches[key] = current_patches - self.patches_uuid = uuid.uuid4() - return list(p) + self.patches_uuid = uuid.uuid4() + return list(p) def get_key_patches(self, filter_prefix=None): model_sd = self.model_state_dict() @@ -324,9 +468,12 @@ def get_key_patches(self, filter_prefix=None): if not k.startswith(filter_prefix): continue bk = self.backup.get(k, None) + hbk = self.hook_backup.get(k, None) weight, set_func, convert_func = get_key_weight(self.model, k) if bk is not None: weight = bk.weight + if hbk is not None: + weight = hbk[0] if convert_func is None: convert_func = lambda a, **kwargs: a @@ -337,13 +484,14 @@ def get_key_patches(self, filter_prefix=None): return p def model_state_dict(self, filter_prefix=None): - sd = self.model.state_dict() - keys = list(sd.keys()) - if filter_prefix is not None: - for k in keys: - if not k.startswith(filter_prefix): - sd.pop(k) - return sd + with self.use_ejected(): + sd = self.model.state_dict() + keys = list(sd.keys()) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) + return sd def patch_weight_to_device(self, key, device_to=None, inplace_update=False): if key not in self.patches: @@ -388,106 +536,117 @@ def _load_list(self): return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): - mem_counter = 0 - patch_counter = 0 - lowvram_counter = 0 - loading = self._load_list() - - load_completely = [] - loading.sort(reverse=True) - for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] - - lowvram_weight = False - - if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: - lowvram_weight = True - lowvram_counter += 1 - if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed - continue + with self.use_ejected(): + self.unpatch_hooks() + mem_counter = 0 + patch_counter = 0 + lowvram_counter = 0 + loading = self._load_list() + + load_completely = [] + loading.sort(reverse=True) + for x in loading: + n = x[1] + m = x[2] + params = x[3] + module_mem = x[0] + + lowvram_weight = False + + if not full_load and hasattr(m, "comfy_cast_weights"): + if mem_counter + module_mem >= lowvram_model_memory: + lowvram_weight = True + lowvram_counter += 1 + if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed + continue - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) - if lowvram_weight: - if weight_key in self.patches: - if force_patch_weights: - self.patch_weight_to_device(weight_key) - else: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - if force_patch_weights: - self.patch_weight_to_device(bias_key) - else: - m.bias_function = LowVramPatch(bias_key, self.patches) - patch_counter += 1 + if lowvram_weight: + if weight_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - else: - if hasattr(m, "comfy_cast_weights"): - if m.comfy_cast_weights: - wipe_lowvram_weight(m) - - if full_load or mem_counter + module_mem < lowvram_model_memory: - mem_counter += module_mem - load_completely.append((module_mem, n, m, params)) - - load_completely.sort(reverse=True) - for x in load_completely: - n = x[1] - m = x[2] - params = x[3] - if hasattr(m, "comfy_patched_weights"): - if m.comfy_patched_weights == True: - continue + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + else: + if hasattr(m, "comfy_cast_weights"): + if m.comfy_cast_weights: + wipe_lowvram_weight(m) + + if full_load or mem_counter + module_mem < lowvram_model_memory: + mem_counter += module_mem + load_completely.append((module_mem, n, m, params)) + + load_completely.sort(reverse=True) + for x in load_completely: + n = x[1] + m = x[2] + params = x[3] + if hasattr(m, "comfy_patched_weights"): + if m.comfy_patched_weights == True: + continue - for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + for param in params: + self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) - logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) - m.comfy_patched_weights = True + logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) + m.comfy_patched_weights = True - for x in load_completely: - x[2].to(device_to) + for x in load_completely: + x[2].to(device_to) - if lowvram_counter > 0: - logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) - self.model.model_lowvram = True - else: - logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) - self.model.model_lowvram = False - if full_load: - self.model.to(device_to) - mem_counter = self.model_size() + if lowvram_counter > 0: + logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + self.model.model_lowvram = True + else: + logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + self.model.model_lowvram = False + if full_load: + self.model.to(device_to) + mem_counter = self.model_size() - self.model.lowvram_patch_counter += patch_counter - self.model.device = device_to - self.model.model_loaded_weight_memory = mem_counter - self.model.current_weight_patches_uuid = self.patches_uuid + self.model.lowvram_patch_counter += patch_counter + self.model.device = device_to + self.model.model_loaded_weight_memory = mem_counter + self.model.current_weight_patches_uuid = self.patches_uuid - def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): - for k in self.object_patches: - old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) - if k not in self.object_patches_backup: - self.object_patches_backup[k] = old + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): + callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) - if lowvram_model_memory == 0: - full_load = True - else: - full_load = False + self.apply_hooks(self.forced_hooks, force_apply=True) + + def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): + with self.use_ejected(): + for k in self.object_patches: + old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) + if k not in self.object_patches_backup: + self.object_patches_backup[k] = old + + if lowvram_model_memory == 0: + full_load = True + else: + full_load = False - if load_weights: - self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) + if load_weights: + self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) + self.inject_model() return self.model def unpatch_model(self, device_to=None, unpatch_weights=True): + self.eject_model() if unpatch_weights: + self.unpatch_hooks() if self.model.model_lowvram: for m in self.model.modules(): wipe_lowvram_weight(m) @@ -523,85 +682,91 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.object_patches_backup.clear() def partially_unload(self, device_to, memory_to_free=0): - memory_freed = 0 - patch_counter = 0 - unload_list = self._load_list() - unload_list.sort() - for unload in unload_list: - if memory_to_free < memory_freed: - break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] - - lowvram_possible = hasattr(m, "comfy_cast_weights") - if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: - move_weight = True - for param in params: - key = "{}.{}".format(n, param) - bk = self.backup.get(key, None) - if bk is not None: - if not lowvram_possible: - move_weight = False - break - - if bk.inplace_update: - comfy.utils.copy_to_param(self.model, key, bk.weight) - else: - comfy.utils.set_attr_param(self.model, key, bk.weight) - self.backup.pop(key) - - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if move_weight: - m.to(device_to) - if lowvram_possible: - if weight_key in self.patches: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - m.bias_function = LowVramPatch(bias_key, self.patches) - patch_counter += 1 - - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - m.comfy_patched_weights = False - memory_freed += module_mem - logging.debug("freed {}".format(n)) + with self.use_ejected(): + memory_freed = 0 + patch_counter = 0 + unload_list = self._load_list() + unload_list.sort() + for unload in unload_list: + if memory_to_free < memory_freed: + break + module_mem = unload[0] + n = unload[1] + m = unload[2] + params = unload[3] + + lowvram_possible = hasattr(m, "comfy_cast_weights") + if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: + move_weight = True + for param in params: + key = "{}.{}".format(n, param) + bk = self.backup.get(key, None) + if bk is not None: + if not lowvram_possible: + move_weight = False + break + + if bk.inplace_update: + comfy.utils.copy_to_param(self.model, key, bk.weight) + else: + comfy.utils.set_attr_param(self.model, key, bk.weight) + self.backup.pop(key) + + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if move_weight: + m.to(device_to) + if lowvram_possible: + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 + + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + m.comfy_patched_weights = False + memory_freed += module_mem + logging.debug("freed {}".format(n)) - self.model.model_lowvram = True - self.model.lowvram_patch_counter += patch_counter - self.model.model_loaded_weight_memory -= memory_freed - return memory_freed + self.model.model_lowvram = True + self.model.lowvram_patch_counter += patch_counter + self.model.model_loaded_weight_memory -= memory_freed + return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): - unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) - # TODO: force_patch_weights should not unload + reload full model - used = self.model.model_loaded_weight_memory - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) - if unpatch_weights: - extra_memory += (used - self.model.model_loaded_weight_memory) - - self.patch_model(load_weights=False) - full_load = False - if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: - return 0 - if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): - full_load = True - current_used = self.model.model_loaded_weight_memory - try: - self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) - except Exception as e: - self.detach() - raise e - - return self.model.model_loaded_weight_memory - current_used + with self.use_ejected(skip_and_inject_on_exit_only=True): + unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) + # TODO: force_patch_weights should not unload + reload full model + used = self.model.model_loaded_weight_memory + self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + if unpatch_weights: + extra_memory += (used - self.model.model_loaded_weight_memory) + + self.patch_model(load_weights=False) + full_load = False + if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: + self.apply_hooks(self.forced_hooks, force_apply=True) + return 0 + if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): + full_load = True + current_used = self.model.model_loaded_weight_memory + try: + self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) + except Exception as e: + self.detach() + raise e + + return self.model.model_loaded_weight_memory - current_used def detach(self, unpatch_all=True): + self.eject_model() self.model_patches_to(self.offload_device) if unpatch_all: self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) + for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH): + callback(self, unpatch_all) return self.model def current_loaded_device(self): @@ -611,6 +776,345 @@ def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float3 print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead") return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) + def cleanup(self): + self.clean_hooks() + if hasattr(self.model, "current_patcher"): + self.model.current_patcher = None + for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP): + callback(self) + + def add_callback(self, call_type: str, callback: Callable): + self.add_callback_with_key(call_type, None, callback) + + def add_callback_with_key(self, call_type: str, key: str, callback: Callable): + c = self.callbacks.setdefault(call_type, {}).setdefault(key, []) + c.append(callback) + + def remove_callbacks_with_key(self, call_type: str, key: str): + c = self.callbacks.get(call_type, {}) + if key in c: + c.pop(key) + + def get_callbacks(self, call_type: str, key: str): + return self.callbacks.get(call_type, {}).get(key, []) + + def get_all_callbacks(self, call_type: str): + c_list = [] + for c in self.callbacks.get(call_type, {}).values(): + c_list.extend(c) + return c_list + + def add_wrapper(self, wrapper_type: str, wrapper: Callable): + self.add_wrapper_with_key(wrapper_type, None, wrapper) + + def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): + w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, []) + w.append(wrapper) + + def remove_wrappers_with_key(self, wrapper_type: str, key: str): + w = self.wrappers.get(wrapper_type, {}) + if key in w: + w.pop(key) + + def get_wrappers(self, wrapper_type: str, key: str): + return self.wrappers.get(wrapper_type, {}).get(key, []) + + def get_all_wrappers(self, wrapper_type: str): + w_list = [] + for w in self.wrappers.get(wrapper_type, {}).values(): + w_list.extend(w) + return w_list + + def set_attachments(self, key: str, attachment): + self.attachments[key] = attachment + + def remove_attachments(self, key: str): + if key in self.attachments: + self.attachments.pop(key) + + def get_attachment(self, key: str): + return self.attachments.get(key, None) + + def set_injections(self, key: str, injections: list[PatcherInjection]): + self.injections[key] = injections + + def remove_injections(self, key: str): + if key in self.injections: + self.injections.pop(key) + + def set_additional_models(self, key: str, models: list['ModelPatcher']): + self.additional_models[key] = models + + def remove_additional_models(self, key: str): + if key in self.additional_models: + self.additional_models.pop(key) + + def get_additional_models_with_key(self, key: str): + return self.additional_models.get(key, []) + + def get_additional_models(self): + all_models = [] + for models in self.additional_models.values(): + all_models.extend(models) + return all_models + + def get_nested_additional_models(self): + def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]): + '''Make sure circular references do not cause infinite recursion.''' + next_models = [] + for model in prev_models: + candidates = model.get_additional_models() + for c in candidates: + if c not in cache_set: + next_models.append(c) + cache_set.add(c) + if len(next_models) == 0: + return prev_models + return prev_models + _evaluate_sub_additional_models(next_models, cache_set) + + all_models = self.get_additional_models() + models_set = set(all_models) + real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set) + return real_all_models + + def use_ejected(self, skip_and_inject_on_exit_only=False): + return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only) + + def inject_model(self): + if self.is_injected or self.skip_injection: + return + for injections in self.injections.values(): + for inj in injections: + inj.inject(self) + self.is_injected = True + if self.is_injected: + for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL): + callback(self) + + def eject_model(self): + if not self.is_injected: + return + for injections in self.injections.values(): + for inj in injections: + inj.eject(self) + self.is_injected = False + for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL): + callback(self) + + def pre_run(self): + if hasattr(self.model, "current_patcher"): + self.model.current_patcher = self + for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): + callback(self) + + def prepare_state(self, timestep): + for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): + callback(self, timestep) + + def restore_hook_patches(self): + if len(self.hook_patches_backup) > 0: + self.hook_patches = self.hook_patches_backup + self.hook_patches_backup = {} + + def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): + self.hook_mode = hook_mode + + def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup): + curr_t = t[0] + reset_current_hooks = False + for hook in hook_group.hooks: + changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t) + # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; + # this will cause the weights to be recalculated when sampling + if changed: + # reset current_hooks if contains hook that changed + if self.current_hooks is not None: + for current_hook in self.current_hooks.hooks: + if current_hook == hook: + reset_current_hooks = True + break + for cached_group in list(self.cached_hook_patches.keys()): + if cached_group.contains(hook): + self.cached_hook_patches.pop(cached_group) + if reset_current_hooks: + self.patch_hooks(None) + + def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None): + self.restore_hook_patches() + registered_hooks: list[comfy.hooks.Hook] = [] + # handle WrapperHooks, if model_options provided + if model_options is not None: + for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): + hook.add_hook_patches(self, model_options, target, registered_hooks) + # handle WeightHooks + weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] + for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): + if hook.hook_ref not in self.hook_patches: + weight_hooks_to_register.append(hook) + if len(weight_hooks_to_register) > 0: + # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state + self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) + for hook in weight_hooks_to_register: + hook.add_hook_patches(self, model_options, target, registered_hooks) + for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): + callback(self, hooks_dict, target) + + def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): + with self.use_ejected(): + # NOTE: this mirrors behavior of add_patches func + current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {}) + p = set() + model_sd = self.model.state_dict() + for k in patches: + offset = None + function = None + if isinstance(k, str): + key = k + else: + offset = k[1] + key = k[0] + if len(k) > 2: + function = k[2] + + if key in model_sd: + p.add(k) + current_patches: list[tuple] = current_hook_patches.get(key, []) + current_patches.append((strength_patch, patches[k], strength_model, offset, function)) + current_hook_patches[key] = current_patches + self.hook_patches[hook.hook_ref] = current_hook_patches + # since should care about these patches too to determine if same model, reroll patches_uuid + self.patches_uuid = uuid.uuid4() + return list(p) + + def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup): + # combined_patches will contain weights of all relevant hooks, per key + combined_patches = {} + if hooks is not None: + for hook in hooks.hooks: + hook_patches: dict = self.hook_patches.get(hook.hook_ref, {}) + for key in hook_patches.keys(): + current_patches: list[tuple] = combined_patches.get(key, []) + if math.isclose(hook.strength, 1.0): + current_patches.extend(hook_patches[key]) + else: + # patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model) + for patch in hook_patches[key]: + new_patch = list(patch) + new_patch[0] *= hook.strength + current_patches.append(tuple(new_patch)) + combined_patches[key] = current_patches + return combined_patches + + def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): + # TODO: return transformer_options dict with any additions from hooks + if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): + return {} + self.patch_hooks(hooks=hooks) + for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): + callback(self, hooks) + return {} + + def patch_hooks(self, hooks: comfy.hooks.HookGroup): + with self.use_ejected(): + self.unpatch_hooks() + if hooks is not None: + model_sd_keys = list(self.model_state_dict().keys()) + memory_counter = None + if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: + # TODO: minimum_counter should have a minimum that conforms to loaded model requirements + memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device), + minimum=comfy.model_management.minimum_inference_memory()*2) + # if have cached weights for hooks, use it + cached_weights = self.cached_hook_patches.get(hooks, None) + if cached_weights is not None: + for key in cached_weights: + if key not in model_sd_keys: + print(f"WARNING cached hook could not patch. key does not exist in model: {key}") + continue + self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter) + else: + relevant_patches = self.get_combined_hook_patches(hooks=hooks) + original_weights = None + if len(relevant_patches) > 0: + original_weights = self.get_key_patches() + for key in relevant_patches: + if key not in model_sd_keys: + print(f"WARNING cached hook would not patch. key does not exist in model: {key}") + continue + self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, + memory_counter=memory_counter) + self.current_hooks = hooks + + def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): + if key not in self.hook_backup: + weight: torch.Tensor = comfy.utils.get_attr(self.model, key) + target_device = self.offload_device + if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: + used = memory_counter.use(weight) + if used: + target_device = weight.device + self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device) + comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1])) + + def clear_cached_hook_weights(self): + self.cached_hook_patches.clear() + self.patch_hooks(None) + + def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): + if key not in combined_patches: + return + + weight, set_func, convert_func = get_key_weight(self.model, key) + weight: torch.Tensor + if key not in self.hook_backup: + target_device = self.offload_device + if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: + used = memory_counter.use(weight) + if used: + target_device = weight.device + self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device) + # TODO: properly handle LowVramPatch, if it ends up an issue + temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True) + if convert_func is not None: + temp_weight = convert_func(temp_weight, inplace=True) + + out_weight = comfy.lora.calculate_weight(combined_patches[key], + temp_weight, + key, original_weights=original_weights) + del original_weights[key] + if set_func is None: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + comfy.utils.copy_to_param(self.model, key, out_weight) + else: + set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) + if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: + # TODO: disable caching if not enough system RAM to do so + target_device = self.offload_device + used = memory_counter.use(weight) + if used: + target_device = weight.device + self.cached_hook_patches.setdefault(hooks, {}) + self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device) + del temp_weight + del out_weight + del weight + + def unpatch_hooks(self) -> None: + with self.use_ejected(): + if len(self.hook_backup) == 0: + self.current_hooks = None + return + keys = list(self.hook_backup.keys()) + for k in keys: + comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) + + self.hook_backup.clear() + self.current_hooks = None + + def clean_hooks(self): + self.unpatch_hooks() + self.clear_cached_hook_weights() + def __del__(self): self.detach(unpatch_all=False) diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py new file mode 100644 index 00000000000..5144691857b --- /dev/null +++ b/comfy/patcher_extension.py @@ -0,0 +1,156 @@ +from __future__ import annotations +from typing import Callable + +class CallbacksMP: + ON_CLONE = "on_clone" + ON_LOAD = "on_load_after" + ON_DETACH = "on_detach_after" + ON_CLEANUP = "on_cleanup" + ON_PRE_RUN = "on_pre_run" + ON_PREPARE_STATE = "on_prepare_state" + ON_APPLY_HOOKS = "on_apply_hooks" + ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches" + ON_INJECT_MODEL = "on_inject_model" + ON_EJECT_MODEL = "on_eject_model" + + # callbacks dict is in the format: + # {"call_type": {"key": [Callable1, Callable2, ...]} } + @classmethod + def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]: + return {} + +def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False): + add_callback_with_key(call_type, None, callback, transformer_options, is_model_options) + +def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.setdefault("transformer_options", {}) + callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {}) + c = callbacks.setdefault(call_type, {}).setdefault(key, []) + c.append(callback) + +def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.get("transformer_options", {}) + c_list = [] + callbacks: dict[str, list] = transformer_options.get("callbacks", {}) + c_list.extend(callbacks.get(call_type, {}).get(key, [])) + return c_list + +def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.get("transformer_options", {}) + c_list = [] + callbacks: dict[str, list] = transformer_options.get("callbacks", {}) + for c in callbacks.get(call_type, {}).values(): + c_list.extend(c) + return c_list + +class WrappersMP: + OUTER_SAMPLE = "outer_sample" + SAMPLER_SAMPLE = "sampler_sample" + CALC_COND_BATCH = "calc_cond_batch" + APPLY_MODEL = "apply_model" + DIFFUSION_MODEL = "diffusion_model" + + # wrappers dict is in the format: + # {"wrapper_type": {"key": [Callable1, Callable2, ...]} } + @classmethod + def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]: + return {} + +def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False): + add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options) + +def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.setdefault("transformer_options", {}) + wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {}) + w = wrappers.setdefault(wrapper_type, {}).setdefault(key, []) + w.append(wrapper) + +def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.get("transformer_options", {}) + w_list = [] + wrappers: dict[str, list] = transformer_options.get("wrappers", {}) + w_list.extend(wrappers.get(wrapper_type, {}).get(key, [])) + return w_list + +def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.get("transformer_options", {}) + w_list = [] + wrappers: dict[str, list] = transformer_options.get("wrappers", {}) + for w in wrappers.get(wrapper_type, {}).values(): + w_list.extend(w) + return w_list + +class WrapperExecutor: + """Handles call stack of wrappers around a function in an ordered manner.""" + def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int): + # NOTE: class_obj exists so that wrappers surrounding a class method can access + # the class instance at runtime via executor.class_obj + self.original = original + self.class_obj = class_obj + self.wrappers = wrappers.copy() + self.idx = idx + self.is_last = idx == len(wrappers) + + def __call__(self, *args, **kwargs): + """Calls the next wrapper or original function, whichever is appropriate.""" + new_executor = self._create_next_executor() + return new_executor.execute(*args, **kwargs) + + def execute(self, *args, **kwargs): + """Used to initiate executor internally - DO NOT use this if you received executor in wrapper.""" + args = list(args) + kwargs = dict(kwargs) + if self.is_last: + return self.original(*args, **kwargs) + return self.wrappers[self.idx](self, *args, **kwargs) + + def _create_next_executor(self) -> 'WrapperExecutor': + new_idx = self.idx + 1 + if new_idx > len(self.wrappers): + raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.") + if self.class_obj is None: + return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx) + return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx) + + @classmethod + def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0): + return cls(original, class_obj=None, wrappers=wrappers, idx=idx) + + @classmethod + def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0): + return cls(original, class_obj, wrappers, idx=idx) + +class PatcherInjection: + def __init__(self, inject: Callable, eject: Callable): + self.inject = inject + self.eject = eject + +def copy_nested_dicts(input_dict: dict): + new_dict = input_dict.copy() + for key, value in input_dict.items(): + if isinstance(value, dict): + new_dict[key] = copy_nested_dicts(value) + elif isinstance(value, list): + new_dict[key] = value.copy() + return new_dict + +def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True): + if copy_dict1: + merged_dict = copy_nested_dicts(dict1) + else: + merged_dict = dict1 + for key, value in dict2.items(): + if isinstance(value, dict): + curr_value = merged_dict.setdefault(key, {}) + merged_dict[key] = merge_nested_dicts(value, curr_value) + elif isinstance(value, list): + merged_dict.setdefault(key, []).extend(value) + else: + merged_dict[key] = value + return merged_dict diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1879e670a53..1252d8a5bf6 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -1,7 +1,16 @@ +from __future__ import annotations +import uuid import torch import comfy.model_management import comfy.conds import comfy.utils +import comfy.hooks +import comfy.patcher_extension +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.model_base import BaseModel + from comfy.controlnet import ControlBase def prepare_mask(noise_mask, shape, device): return comfy.utils.reshape_mask(noise_mask, shape).to(device) @@ -10,9 +19,43 @@ def get_models_from_cond(cond, model_type): models = [] for c in cond: if model_type in c: - models += [c[model_type]] + if isinstance(c[model_type], list): + models += c[model_type] + else: + models += [c[model_type]] return models +def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): + # get hooks from conds, and collect cnets so they can be checked for extra_hooks + cnets: list[ControlBase] = [] + for c in cond: + if 'hooks' in c: + for hook in c['hooks'].hooks: + hook: comfy.hooks.Hook + with_type = hooks_dict.setdefault(hook.hook_type, {}) + with_type[hook] = None + if 'control' in c: + cnets.append(c['control']) + + def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list): + if cnet.extra_hooks is not None: + _list.append(cnet.extra_hooks) + if cnet.previous_controlnet is None: + return _list + return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list) + + hooks_list = [] + cnets = set(cnets) + for base_cnet in cnets: + get_extra_hooks_from_cnet(base_cnet, hooks_list) + extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) + if extra_hooks is not None: + for hook in extra_hooks.hooks: + with_type = hooks_dict.setdefault(hook.hook_type, {}) + with_type[hook] = None + + return hooks_dict + def convert_cond(cond): out = [] for c in cond: @@ -22,17 +65,22 @@ def convert_cond(cond): model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove temp["cross_attn"] = c[0] temp["model_conds"] = model_conds + temp["uuid"] = uuid.uuid4() out.append(temp) return out def get_additional_models(conds, dtype): """loads additional models in conditioning""" - cnets = [] + cnets: list[ControlBase] = [] gligen = [] + add_models = [] + hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {} for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") + add_models += get_models_from_cond(conds[k], "additional_models") + get_hooks_from_cond(conds[k], hooks) control_nets = set(cnets) @@ -43,7 +91,9 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - models = control_models + gligen + hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] + models = control_models + gligen + add_models + hook_models + return models, inference_memory def cleanup_additional_models(models): @@ -53,10 +103,11 @@ def cleanup_additional_models(models): m.cleanup() -def prepare_sampling(model, noise_shape, conds): +def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): device = model.load_device - real_model = None + real_model: 'BaseModel' = None models, inference_memory = get_additional_models(conds, model.model_dtype()) + models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) @@ -72,3 +123,14 @@ def cleanup_models(conds, models): control_cleanup += get_models_from_cond(conds[k], "control") cleanup_additional_models(set(control_cleanup)) + +def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): + # check for hooks in conds - if not registered, see if can be applied + hooks = {} + for k in conds: + get_hooks_from_cond(conds[k], hooks) + # add wrappers and callbacks from ModelPatcher to transformer_options + model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) + model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) + # register hooks on model/model_options + model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options) diff --git a/comfy/samplers.py b/comfy/samplers.py index 94cba03b88f..b4c42160d3c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,11 +1,21 @@ +from __future__ import annotations from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.model_base import BaseModel + from comfy.controlnet import ControlBase import torch import collections from comfy import model_management import math import logging +import comfy.samplers import comfy.sampler_helpers +import comfy.model_patcher +import comfy.patcher_extension +import comfy.hooks import scipy.stats import numpy @@ -70,6 +80,7 @@ def get_area_and_mult(conds, x_in, timestep_in): for c in model_conds: conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + hooks = conds.get('hooks', None) control = conds.get('control', None) patches = None @@ -85,8 +96,8 @@ def get_area_and_mult(conds, x_in, timestep_in): patches['middle_patch'] = [gligen_patch] - cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) - return cond_obj(input_x, mult, conditioning, area, control, patches) + cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks']) + return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks) def cond_equal_size(c1, c2): if c1 is c2: @@ -138,110 +149,184 @@ def cond_cat(c_list): return out -def calc_cond_batch(model, conds, x_in, timestep, model_options): +def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep): + # need to figure out remaining unmasked area for conds + default_mults = [] + for _ in default_conds: + default_mults.append(torch.ones_like(x_in)) + # look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond + for lora_hooks, to_run in hooked_to_run.items(): + for cond_obj, i in to_run: + # if no default_cond for cond_type, do nothing + if len(default_conds[i]) == 0: + continue + area: list[int] = cond_obj.area + if area is not None: + curr_default_mult: torch.Tensor = default_mults[i] + dims = len(area) // 2 + for i in range(dims): + curr_default_mult = curr_default_mult.narrow(i + 2, area[i + dims], area[i]) + curr_default_mult -= cond_obj.mult + else: + default_mults[i] -= cond_obj.mult + # for each default_mult, ReLU to make negatives=0, and then check for any nonzeros + for i, mult in enumerate(default_mults): + # if no default_cond for cond type, do nothing + if len(default_conds[i]) == 0: + continue + torch.nn.functional.relu(mult, inplace=True) + # if mult is all zeros, then don't add default_cond + if torch.max(mult) == 0.0: + continue + + cond = default_conds[i] + for x in cond: + # do get_area_and_mult to get all the expected values + p = comfy.samplers.get_area_and_mult(x, x_in, timestep) + if p is None: + continue + # replace p's mult with calculated mult + p = p._replace(mult=mult) + if p.hooks is not None: + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + hooked_to_run.setdefault(p.hooks, list()) + hooked_to_run[p.hooks] += [(p, i)] + +def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): + executor = comfy.patcher_extension.WrapperExecutor.new_executor( + _calc_cond_batch, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) + ) + return executor.execute(model, conds, x_in, timestep, model_options) + +def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): out_conds = [] out_counts = [] - to_run = [] + # separate conds by matching hooks + hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {} + default_conds = [] + has_default_conds = False for i in range(len(conds)): out_conds.append(torch.zeros_like(x_in)) out_counts.append(torch.ones_like(x_in) * 1e-37) cond = conds[i] + default_c = [] if cond is not None: for x in cond: - p = get_area_and_mult(x, x_in, timestep) + if 'default' in x: + default_c.append(x) + has_default_conds = True + continue + p = comfy.samplers.get_area_and_mult(x, x_in, timestep) if p is None: continue + if p.hooks is not None: + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + hooked_to_run.setdefault(p.hooks, list()) + hooked_to_run[p.hooks] += [(p, i)] + default_conds.append(default_c) + + if has_default_conds: + finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep) + + model.current_patcher.prepare_state(timestep) + + # run every hooked_to_run separately + for hooks, to_run in hooked_to_run.items(): + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + free_memory = model_management.get_free_memory(x_in.device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + if model.memory_required(input_shape) * 1.5 < free_memory: + to_batch = batch_amount + break + + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + uuids = [] + area = [] + control = None + patches = None + for x in to_batch: + o = to_run.pop(x) + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + uuids.append(p.uuid) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) + + transformer_options = model.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) + + if patches is not None: + # TODO: replace with merge_nested_dicts function + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches + else: + transformer_options["patches"] = patches - to_run += [(p, i)] - - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - to_batch_temp = [] - for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] - - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] - - free_memory = model_management.get_free_memory(x_in.device) - for i in range(1, len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp)//i] - input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) * 1.5 < free_memory: - to_batch = batch_amount - break - - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x.append(p.input_x) - mult.append(p.mult) - c.append(p.conditioning) - area.append(p.area) - cond_or_uncond.append(o[1]) - control = p.control - patches = p.patches - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - - transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() - - if patches is not None: - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - transformer_options["patches"] = cur_patches - else: - transformer_options["patches"] = patches + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["sigmas"] = timestep + c['transformer_options'] = transformer_options - c['transformer_options'] = transformer_options + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) - else: - output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - - for o in range(batch_chunks): - cond_index = cond_or_uncond[o] - a = area[o] - if a is None: - out_conds[cond_index] += output[o] * mult[o] - out_counts[cond_index] += mult[o] + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: - out_c = out_conds[cond_index] - out_cts = out_counts[cond_index] - dims = len(a) // 2 - for i in range(dims): - out_c = out_c.narrow(i + 2, a[i + dims], a[i]) - out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) - out_c += output[o] * mult[o] - out_cts += mult[o] + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] for i in range(len(out_conds)): out_conds[i] /= out_counts[i] @@ -500,10 +585,15 @@ def calculate_start_end_timesteps(model, conds): timestep_start = None timestep_end = None - if 'start_percent' in x: - timestep_start = s.percent_to_sigma(x['start_percent']) - if 'end_percent' in x: - timestep_end = s.percent_to_sigma(x['end_percent']) + # handle clip hook schedule, if needed + if 'clip_start_percent' in x: + timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0))) + timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0))) + else: + if 'start_percent' in x: + timestep_start = s.percent_to_sigma(x['start_percent']) + if 'end_percent' in x: + timestep_end = s.percent_to_sigma(x['end_percent']) if (timestep_start is not None) or (timestep_end is not None): n = x.copy() @@ -673,6 +763,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N if k != kk: create_cond_with_same_area_if_none(conds[kk], c) + for k in conds: + for c in conds[k]: + if 'hooks' in c: + for hook in c['hooks'].hooks: + hook.initialize_timesteps(model) + for k in conds: pre_run_control(model, conds[k]) @@ -685,9 +781,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N return conds + +def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]): + # determine which ControlNets have extra_hooks that should be combined with normal hooks + hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {} + for k in conds: + for kk in conds[k]: + if 'control' in kk: + control: 'ControlBase' = kk['control'] + extra_hooks = control.get_extra_hooks() + if len(extra_hooks) > 0: + hooks: comfy.hooks.HookGroup = kk.get('hooks', None) + to_replace = hook_replacement.setdefault((control, hooks), []) + to_replace.append(kk) + # if nothing to replace, do nothing + if len(hook_replacement) == 0: + return + + # for optimal sampling performance, common ControlNets + hook combos should have identical hooks + # on the cond dicts + for key, conds_to_modify in hook_replacement.items(): + control = key[0] + hooks = key[1] + hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks]) + # if combined hooks are not None, set as new hooks for all relevant conds + if hooks is not None: + for cond in conds_to_modify: + cond['hooks'] = hooks + + +def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): + hooks_set = set() + for k in conds: + for kk in conds[k]: + hooks_set.add(kk.get('hooks', None)) + return len(hooks_set) + + class CFGGuider: def __init__(self, model_patcher): - self.model_patcher = model_patcher + self.model_patcher: 'ModelPatcher' = model_patcher self.model_options = model_patcher.model_options self.original_conds = {} self.cfg = 1.0 @@ -714,19 +847,17 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) - extra_args = {"model_options": self.model_options, "seed":seed} + extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed} - samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) + executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( + sampler.sample, + sampler, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True) + ) + samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) - def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - if sigmas.shape[-1] == 0: - return latent_image - - self.conds = {} - for k in self.original_conds: - self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) - + def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) device = self.model_patcher.load_device @@ -737,14 +868,48 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba latent_image = latent_image.to(device) sigmas = sigmas.to(device) - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + try: + self.model_patcher.pre_run() + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + finally: + self.model_patcher.cleanup() comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model - del self.conds del self.loaded_models return output + def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + if sigmas.shape[-1] == 0: + return latent_image + + self.conds = {} + for k in self.original_conds: + self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) + preprocess_conds_hooks(self.conds) + + try: + orig_model_options = self.model_options + self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options) + # if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step) + orig_hook_mode = self.model_patcher.hook_mode + if get_total_hook_groups_in_conds(self.conds) <= 1: + self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram + comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) + executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( + self.outer_sample, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) + ) + output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + finally: + self.model_options = orig_model_options + self.model_patcher.hook_mode = orig_hook_mode + self.model_patcher.restore_hook_patches() + + del self.conds + return output + def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): cfg_guider = CFGGuider(model) diff --git a/comfy/sd.py b/comfy/sd.py index e2af7078121..ebae7f99688 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,8 +1,10 @@ +from __future__ import annotations import torch from enum import Enum import logging from comfy import model_management +from comfy.utils import ProgressBar from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_c_coder import StageC_coder @@ -33,6 +35,7 @@ import comfy.model_patcher import comfy.lora import comfy.lora_convert +import comfy.hooks import comfy.t2i_adapter.adapter import comfy.taesd.taesd @@ -98,9 +101,13 @@ def __init__(self, target=None, embedding_directory=None, no_init=False, tokeniz self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram + self.patcher.is_clip = True + self.apply_hooks_to_conds = None if params['device'] == load_device: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None + self.use_clip_schedule = False logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) def clone(self): @@ -109,6 +116,8 @@ def clone(self): n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx + n.use_clip_schedule = self.use_clip_schedule + n.apply_hooks_to_conds = self.apply_hooks_to_conds return n def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): @@ -120,6 +129,69 @@ def clip_layer(self, layer_idx): def tokenize(self, text, return_word_ids=False): return self.tokenizer.tokenize_with_weights(text, return_word_ids) + def add_hooks_to_dict(self, pooled_dict: dict[str]): + if self.apply_hooks_to_conds: + pooled_dict["hooks"] = self.apply_hooks_to_conds + return pooled_dict + + def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True): + all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = [] + all_hooks = self.patcher.forced_hooks + if all_hooks is None or not self.use_clip_schedule: + # if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict + return_pooled = "unprojected" if unprojected else True + pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True) + cond = pooled_dict.pop("cond") + # add/update any keys with the provided add_dict + pooled_dict.update(add_dict) + all_cond_pooled.append([cond, pooled_dict]) + else: + scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule() + + self.cond_stage_model.reset_clip_options() + if self.layer_idx is not None: + self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) + if unprojected: + self.cond_stage_model.set_clip_options({"projected_pooled": False}) + + self.load_model() + all_hooks.reset() + self.patcher.patch_hooks(None) + if show_pbar: + pbar = ProgressBar(len(scheduled_keyframes)) + + for scheduled_opts in scheduled_keyframes: + t_range = scheduled_opts[0] + # don't bother encoding any conds outside of start_percent and end_percent bounds + if "start_percent" in add_dict: + if t_range[1] < add_dict["start_percent"]: + continue + if "end_percent" in add_dict: + if t_range[0] > add_dict["end_percent"]: + continue + hooks_keyframes = scheduled_opts[1] + for hook, keyframe in hooks_keyframes: + hook.hook_keyframe._current_keyframe = keyframe + # apply appropriate hooks with values that match new hook_keyframe + self.patcher.patch_hooks(all_hooks) + # perform encoding as normal + o = self.cond_stage_model.encode_token_weights(tokens) + cond, pooled = o[:2] + pooled_dict = {"pooled_output": pooled} + # add clip_start_percent and clip_end_percent in pooled + pooled_dict["clip_start_percent"] = t_range[0] + pooled_dict["clip_end_percent"] = t_range[1] + # add/update any keys with the provided add_dict + pooled_dict.update(add_dict) + # add hooks stored on clip + self.add_hooks_to_dict(pooled_dict) + all_cond_pooled.append([cond, pooled_dict]) + if show_pbar: + pbar.update(1) + model_management.throw_exception_if_processing_interrupted() + all_hooks.reset() + return all_cond_pooled + def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False): self.cond_stage_model.reset_clip_options() @@ -137,6 +209,7 @@ def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False): if len(o) > 2: for k in o[2]: out[k] = o[2][k] + self.add_hooks_to_dict(out) return out if return_pooled: diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py index 3087b917b41..b8e241578e7 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -17,8 +17,7 @@ def INPUT_TYPES(s): def encode(self, clip, ascore, width, height, text): tokens = clip.tokenize(text) - cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) - return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], ) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), ) class CLIPTextEncodeSDXL: @classmethod @@ -47,8 +46,7 @@ def encode(self, clip, width, height, crop_w, crop_h, target_width, target_heigh tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) - return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], ) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), ) NODE_CLASS_MAPPINGS = { "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index b690432b55b..2ae23f73550 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -18,10 +18,7 @@ def encode(self, clip, clip_l, t5xxl, guidance): tokens = clip.tokenize(clip_l) tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] - output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) - cond = output.pop("cond") - output["guidance"] = guidance - return ([[cond, output]], ) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), ) class FluxGuidance: @classmethod diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py new file mode 100644 index 00000000000..7b5344e5eac --- /dev/null +++ b/comfy_extras/nodes_hooks.py @@ -0,0 +1,697 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Union +import torch +from collections.abc import Iterable + +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.sd import CLIP + +import comfy.hooks +import comfy.sd +import comfy.utils +import folder_paths + +########################################### +# Mask, Combine, and Hook Conditioning +#------------------------------------------ +class PairConditioningSetProperties: + NodeId = 'PairConditioningSetProperties' + NodeName = 'Cond Pair Set Props' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive_NEW": ("CONDITIONING", ), + "negative_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "set_properties" + + def set_properties(self, positive_NEW, negative_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_positive, final_negative) + +class PairConditioningSetPropertiesAndCombine: + NodeId = 'PairConditioningSetPropertiesAndCombine' + NodeName = 'Cond Pair Set Props Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "positive_NEW": ("CONDITIONING", ), + "negative_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "set_properties" + + def set_properties(self, positive, negative, positive_NEW, negative_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_positive, final_negative) + +class ConditioningSetProperties: + NodeId = 'ConditioningSetProperties' + NodeName = 'Cond Set Props' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/cond single" + FUNCTION = "set_properties" + + def set_properties(self, cond_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + (final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_cond,) + +class ConditioningSetPropertiesAndCombine: + NodeId = 'ConditioningSetPropertiesAndCombine' + NodeName = 'Cond Set Props Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond": ("CONDITIONING", ), + "cond_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/cond single" + FUNCTION = "set_properties" + + def set_properties(self, cond, cond_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + (final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_cond,) + +class PairConditioningCombine: + NodeId = 'PairConditioningCombine' + NodeName = 'Cond Pair Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive_A": ("CONDITIONING",), + "negative_A": ("CONDITIONING",), + "positive_B": ("CONDITIONING",), + "negative_B": ("CONDITIONING",), + }, + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "combine" + + def combine(self, positive_A, negative_A, positive_B, negative_B): + final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],) + return (final_positive, final_negative,) + +class PairConditioningSetDefaultAndCombine: + NodeId = 'PairConditioningSetDefaultCombine' + NodeName = 'Cond Pair Set Default Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "positive_DEFAULT": ("CONDITIONING",), + "negative_DEFAULT": ("CONDITIONING",), + }, + "optional": { + "hooks": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "set_default_and_combine" + + def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT, + hooks: comfy.hooks.HookGroup=None): + final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], + hooks=hooks) + return (final_positive, final_negative) + +class ConditioningSetDefaultAndCombine: + NodeId = 'ConditioningSetDefaultCombine' + NodeName = 'Cond Set Default Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond": ("CONDITIONING",), + "cond_DEFAULT": ("CONDITIONING",), + }, + "optional": { + "hooks": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/cond single" + FUNCTION = "set_default_and_combine" + + def set_default_and_combine(self, cond, cond_DEFAULT, + hooks: comfy.hooks.HookGroup=None): + (final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT], + hooks=hooks) + return (final_conditioning,) + +class SetClipHooks: + NodeId = 'SetClipHooks' + NodeName = 'Set CLIP Hooks' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "clip": ("CLIP",), + "apply_to_conds": ("BOOLEAN", {"default": True}), + "schedule_clip": ("BOOLEAN", {"default": False}) + }, + "optional": { + "hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CLIP",) + CATEGORY = "advanced/hooks/clip" + FUNCTION = "apply_hooks" + + def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): + if hooks is not None: + clip = clip.clone() + if apply_to_conds: + clip.apply_hooks_to_conds = hooks + clip.patcher.forced_hooks = hooks.clone() + clip.use_clip_schedule = schedule_clip + if not clip.use_clip_schedule: + clip.patcher.forced_hooks.set_keyframes_on_hooks(None) + clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip) + return (clip,) + +class ConditioningTimestepsRange: + NodeId = 'ConditioningTimestepsRange' + NodeName = 'Timesteps Range' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) + }, + } + + EXPERIMENTAL = True + RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE") + RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE") + CATEGORY = "advanced/hooks" + FUNCTION = "create_range" + + def create_range(self, start_percent: float, end_percent: float): + return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0)) +#------------------------------------------ +########################################### + + +########################################### +# Create Hooks +#------------------------------------------ +class CreateHookLora: + NodeId = 'CreateHookLora' + NodeName = 'Create Hook LoRA' + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "prev_hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/create" + FUNCTION = "create_hook" + + def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None): + if prev_hooks is None: + prev_hooks = comfy.hooks.HookGroup() + prev_hooks.clone() + + if strength_model == 0 and strength_clip == 0: + return (prev_hooks,) + + lora_path = folder_paths.get_full_path("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + temp = self.loaded_lora + self.loaded_lora = None + del temp + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip) + return (prev_hooks.clone_and_combine(hooks),) + +class CreateHookLoraModelOnly(CreateHookLora): + NodeId = 'CreateHookLoraModelOnly' + NodeName = 'Create Hook LoRA (MO)' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "prev_hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/create" + FUNCTION = "create_hook_model_only" + + def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None): + return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks) + +class CreateHookModelAsLora: + NodeId = 'CreateHookModelAsLora' + NodeName = 'Create Hook Model as LoRA' + + def __init__(self): + # when not None, will be in following format: + # (ckpt_path: str, weights_model: dict, weights_clip: dict) + self.loaded_weights = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "prev_hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/create" + FUNCTION = "create_hook" + + def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float, + prev_hooks: comfy.hooks.HookGroup=None): + if prev_hooks is None: + prev_hooks = comfy.hooks.HookGroup() + prev_hooks.clone() + + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + weights_model = None + weights_clip = None + if self.loaded_weights is not None: + if self.loaded_weights[0] == ckpt_path: + weights_model = self.loaded_weights[1] + weights_clip = self.loaded_weights[2] + else: + temp = self.loaded_weights + self.loaded_weights = None + del temp + + if weights_model is None: + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + weights_model = comfy.hooks.get_patch_weights_from_model(out[0]) + weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1]) + self.loaded_weights = (ckpt_path, weights_model, weights_clip) + + hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip, + strength_model=strength_model, strength_clip=strength_clip) + return (prev_hooks.clone_and_combine(hooks),) + +class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora): + NodeId = 'CreateHookModelAsLoraModelOnly' + NodeName = 'Create Hook Model as LoRA (MO)' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "prev_hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/create" + FUNCTION = "create_hook_model_only" + + def create_hook_model_only(self, ckpt_name: str, strength_model: float, + prev_hooks: comfy.hooks.HookGroup=None): + return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks) +#------------------------------------------ +########################################### + + +########################################### +# Schedule Hooks +#------------------------------------------ +class SetHookKeyframes: + NodeId = 'SetHookKeyframes' + NodeName = 'Set Hook Keyframes' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "hooks": ("HOOKS",), + }, + "optional": { + "hook_kf": ("HOOK_KEYFRAMES",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/scheduling" + FUNCTION = "set_hook_keyframes" + + def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None): + if hook_kf is not None: + hooks = hooks.clone() + hooks.set_keyframes_on_hooks(hook_kf=hook_kf) + return (hooks,) + +class CreateHookKeyframe: + NodeId = 'CreateHookKeyframe' + NodeName = 'Create Hook Keyframe' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }, + "optional": { + "prev_hook_kf": ("HOOK_KEYFRAMES",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOK_KEYFRAMES",) + RETURN_NAMES = ("HOOK_KF",) + CATEGORY = "advanced/hooks/scheduling" + FUNCTION = "create_hook_keyframe" + + def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None): + if prev_hook_kf is None: + prev_hook_kf = comfy.hooks.HookKeyframeGroup() + prev_hook_kf = prev_hook_kf.clone() + keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent) + prev_hook_kf.add(keyframe) + return (prev_hook_kf,) + +class CreateHookKeyframesFromFloats: + NodeId = 'CreateHookKeyframesFromFloats' + NodeName = 'Create Hook Keyframes From Floats' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_hook_kf": ("HOOK_KEYFRAMES",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOK_KEYFRAMES",) + RETURN_NAMES = ("HOOK_KF",) + CATEGORY = "advanced/hooks/scheduling" + FUNCTION = "create_hook_keyframes" + + def create_hook_keyframes(self, floats_strength: Union[float, list[float]], + start_percent: float, end_percent: float, + prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False): + if prev_hook_kf is None: + prev_hook_kf = comfy.hooks.HookKeyframeGroup() + prev_hook_kf = prev_hook_kf.clone() + if type(floats_strength) in (float, int): + floats_strength = [float(floats_strength)] + elif isinstance(floats_strength, Iterable): + pass + else: + raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.") + percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength), + method=comfy.hooks.InterpolationMethod.LINEAR) + + is_first = True + for percent, strength in zip(percents, floats_strength): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = False + prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps)) + if print_keyframes: + print(f"Hook Keyframe - start_percent:{percent} = {strength}") + return (prev_hook_kf,) +#------------------------------------------ +########################################### + + +class SetModelHooksOnCond: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "conditioning": ("CONDITIONING",), + "hooks": ("HOOKS",), + }, + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/manual" + FUNCTION = "attach_hook" + + def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup): + return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),) + + +########################################### +# Combine Hooks +#------------------------------------------ +class CombineHooks: + NodeId = 'CombineHooks2' + NodeName = 'Combine Hooks [2]' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "hooks_A": ("HOOKS",), + "hooks_B": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/combine" + FUNCTION = "combine_hooks" + + def combine_hooks(self, + hooks_A: comfy.hooks.HookGroup=None, + hooks_B: comfy.hooks.HookGroup=None): + candidates = [hooks_A, hooks_B] + return (comfy.hooks.HookGroup.combine_all_hooks(candidates),) + +class CombineHooksFour: + NodeId = 'CombineHooks4' + NodeName = 'Combine Hooks [4]' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "hooks_A": ("HOOKS",), + "hooks_B": ("HOOKS",), + "hooks_C": ("HOOKS",), + "hooks_D": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/combine" + FUNCTION = "combine_hooks" + + def combine_hooks(self, + hooks_A: comfy.hooks.HookGroup=None, + hooks_B: comfy.hooks.HookGroup=None, + hooks_C: comfy.hooks.HookGroup=None, + hooks_D: comfy.hooks.HookGroup=None): + candidates = [hooks_A, hooks_B, hooks_C, hooks_D] + return (comfy.hooks.HookGroup.combine_all_hooks(candidates),) + +class CombineHooksEight: + NodeId = 'CombineHooks8' + NodeName = 'Combine Hooks [8]' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "hooks_A": ("HOOKS",), + "hooks_B": ("HOOKS",), + "hooks_C": ("HOOKS",), + "hooks_D": ("HOOKS",), + "hooks_E": ("HOOKS",), + "hooks_F": ("HOOKS",), + "hooks_G": ("HOOKS",), + "hooks_H": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/combine" + FUNCTION = "combine_hooks" + + def combine_hooks(self, + hooks_A: comfy.hooks.HookGroup=None, + hooks_B: comfy.hooks.HookGroup=None, + hooks_C: comfy.hooks.HookGroup=None, + hooks_D: comfy.hooks.HookGroup=None, + hooks_E: comfy.hooks.HookGroup=None, + hooks_F: comfy.hooks.HookGroup=None, + hooks_G: comfy.hooks.HookGroup=None, + hooks_H: comfy.hooks.HookGroup=None): + candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H] + return (comfy.hooks.HookGroup.combine_all_hooks(candidates),) +#------------------------------------------ +########################################### + +node_list = [ + # Create + CreateHookLora, + CreateHookLoraModelOnly, + CreateHookModelAsLora, + CreateHookModelAsLoraModelOnly, + # Scheduling + SetHookKeyframes, + CreateHookKeyframe, + CreateHookKeyframesFromFloats, + # Combine + CombineHooks, + CombineHooksFour, + CombineHooksEight, + # Attach + ConditioningSetProperties, + ConditioningSetPropertiesAndCombine, + PairConditioningSetProperties, + PairConditioningSetPropertiesAndCombine, + ConditioningSetDefaultAndCombine, + PairConditioningSetDefaultAndCombine, + PairConditioningCombine, + SetClipHooks, + # Other + ConditioningTimestepsRange, +] +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + +for node in node_list: + NODE_CLASS_MAPPINGS[node.NodeId] = node + NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index b03eaf6a204..2bd295e2459 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -15,9 +15,7 @@ def encode(self, clip, bert, mt5xl): tokens = clip.tokenize(bert) tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"] - output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) - cond = output.pop("cond") - return ([[cond, output]], ) + return (clip.encode_from_tokens_scheduled(tokens), ) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 6ef3c293d98..d75b29e606f 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -82,8 +82,7 @@ def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding): tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) - return ([[cond, {"pooled_output": pooled}]], ) + return (clip.encode_from_tokens_scheduled(tokens), ) class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): diff --git a/nodes.py b/nodes.py index fb504da3543..260bb5e153f 100644 --- a/nodes.py +++ b/nodes.py @@ -62,9 +62,8 @@ def INPUT_TYPES(s): def encode(self, clip, text): tokens = clip.tokenize(text) - output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) - cond = output.pop("cond") - return ([[cond, output]], ) + return (clip.encode_from_tokens_scheduled(tokens), ) + class ConditioningCombine: @classmethod @@ -2149,6 +2148,7 @@ def init_builtin_extra_nodes(): "nodes_mochi.py", "nodes_slg.py", "nodes_lt.py", + "nodes_hooks.py", ] import_failed = []