From 0ee322ec5f338791c5836b79830e2f419d6fcc79 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Mon, 2 Dec 2024 13:51:02 -0600
Subject: [PATCH] ModelPatcher Overhaul and Hook Support (#5583)

* Added hook_patches to ModelPatcher for weights (model)
* Initial changes to calc_cond_batch to eventually support hook_patches
* Added current_patcher property to BaseModel
* Consolidated add_hook_patches_as_diffs into add_hook_patches func, fixed fp8 support for model-as-lora feature
* Added call to initialize_timesteps on hooks in process_conds func, and added call to prepare_current_keyframe on hooks in calc_cond_batch
* Added default_conds support in calc_cond_batch func
* Added initial set of hook-related nodes, added code to register hooks for loras/model-as-loras, small renaming/refactoring
* Made CLIP work with hook patches
* Added initial hook scheduling nodes, small renaming/refactoring
* Fixed MaxSpeed and default conds implementations
* Added support for adding weight hooks that aren't registered on the ModelPatcher at sampling time
* Made Set Clip Hooks node work with hooks from Create Hook nodes, began work on better Create Hook Model As LoRA node
* Initial work on adding 'model_as_lora' lora type to calculate_weight
* Continued work on simpler Create Hook Model As LoRA node, started to implement ModelPatcher callbacks, attachments, and additional_models
* Fix incorrect ref to create_hook_patches_clone after moving function
* Added injections support to ModelPatcher + necessary bookkeeping, added additional_models support in ModelPatcher, conds, and hooks
* Added wrappers to ModelPatcher to facilitate standardized function wrapping
* Started scaffolding for other hook types, refactored get_hooks_from_cond to organize hooks by type
* Fix skip_until_exit logic bug breaking injection after first run of model
* Updated clone_has_same_weights function to account for new ModelPatcher properties, improved AutoPatcherEjector usage in partially_load
* Added WrapperExecutor for non-classbound functions, added calc_cond_batch wrappers
* Refactored callbacks+wrappers to allow storing lists by id
* Added forward_timestep_embed_patch type, added helper functions on ModelPatcher for emb_patch and forward_timestep_embed_patch, added helper functions for removing callbacks/wrappers/additional_models by key, added custom_should_register prop to hooks
* Added get_attachment func on ModelPatcher
* Implement basic MemoryCounter system for determining whether cached weights due to hooks should be offloaded in hooks_backup
* Modified ControlNet/T2IAdapter get_control function to receive transformer_options as an additional parameter, made the model_options stored in extra_args in inner_sample be a clone of the original model_options instead of the same ref
* Added create_model_options_clone func, modified type annotations to use __future__ so that I can use the better type annotations
* Refactored WrapperExecutor code to remove need for WrapperClassExecutor (now gone), added sampler.sample wrapper (pending review; will likely keep, but will see what hacks this could currently let me get rid of in ACN/ADE)
* Added Combine versions of Cond/Cond Pair Set Props nodes, renamed Pair Cond to Cond Pair, fixed default conds never applying hooks (due to hooks key typo)
* Renamed Create Hook Model As LoRA nodes to make the test node the main one (more changes pending)
* Added uuid to conds in CFGGuider and uuids to transformer_options to allow uniquely identifying conds in batches during sampling
* Fixed models not being unloaded properly due to current_patcher reference; the current ComfyUI model cleanup code requires that nothing else has a reference to the ModelPatcher instances
* Fixed default conds not respecting hook keyframes, made keyframes not reset cache when strength is unchanged, fixed Cond Set Default Combine throwing error, fixed model-as-lora throwing error during calculate_weight after a recent ComfyUI update, small refactoring/scaffolding changes for hooks
* Changed CreateHookModelAsLoraTest to be the new CreateHookModelAsLora, renamed old ones as 'direct'; they will be removed prior to merge
* Added initial support within CLIP Text Encode (Prompt) node for scheduling weight hook CLIP strength via clip_start_percent/clip_end_percent on conds, added schedule_clip toggle to Set CLIP Hooks node, small cleanup/fixes
* Fix range check in get_hooks_for_clip_schedule so that proper keyframes get assigned to corresponding ranges
* Optimized CLIP hook scheduling to treat same strength as same keyframe
* Less fragile memory management.
* Make encode_from_tokens_scheduled call cleaner, roll back change in model_patcher.py for hook_patches_backup dict
* Fix issue.
* Remove useless function.
* Prevent and detect some types of memory leaks.
* Run garbage collector when switching workflow if needed.
* Moved WrappersMP/CallbacksMP/WrapperExecutor to patcher_extension.py
* Refactored code to store wrappers and callbacks in transformer_options, added apply_model and diffusion_model.forward wrappers
* Fix issue.
* Refactored hooks in calc_cond_batch to be part of get_area_and_mult tuple, added extra_hooks to ControlBase to allow custom controlnets w/ hooks, small cleanup and renaming
* Fixed inconsistency of results when schedule_clip is set to False, small renaming/typo fixing, added initial support for ControlNet extra_hooks to work in tandem with normal cond hooks, initial work on calc_cond_batch merging all subdicts in returned transformer_options
* Modified callbacks and wrappers so that unregistered types can be used, allowing custom_nodes to have their own unique callbacks/wrappers if desired
* Updated different hook types to reflect actual progress of implementation, initial scaffolding for working WrapperHook functionality
* Fixed existing weight hook_patches (pre-registered) not working properly for CLIP
* Removed Register/Direct hook nodes since they were present only for testing, removed diff-related weight hook calculation as improved_memory removes unload_model_clones and using sample-time registered hooks is less hacky
* Added clip scheduling support to all other native ComfyUI text encoding nodes (sdxl, flux, hunyuan, sd3)
* Made WrapperHook functional, added another wrapper/callback getter, added ON_DETACH callback to ModelPatcher
* Made opt_hooks append by default instead of replace, renamed comfy.hooks set functions to be more accurate
* Added apply_to_conds to Set CLIP Hooks, modified relevant code to allow text encoding to automatically apply hooks to output conds when apply_to_conds is set to True
* Fix cached_hook_patches not respecting target_device/memory_counter results
* Fixed issue with setting weights from hooks instead of copying them, added additional memory_counter check when caching hook patches
* Remove unnecessary torch.no_grad calls for hook patches
* Increased MemoryCounter minimum memory to leave free by *2 until a better way to get inference memory estimate of currently loaded models exists
* For encode_from_tokens_scheduled, allow start_percent and end_percent in add_dict to limit which scheduled conds get encoded for optimization purposes
* Removed a .to call on results of calculate_weight in patch_hook_weight_to_device that was screwing up the intermediate results for fp8 prior to being passed into stochastic_rounding call
* Made encode_from_tokens_scheduled work when no hooks are set on patcher
* Small cleanup of comments
* Turn off hook patch caching when only 1 hook present in sampling, replace some current_hook = None with calls to self.patch_hooks(None) instead to avoid a potential edge case
* On Cond/Cond Pair nodes, removed opt_ prefix from optional inputs
* Allow both FLOATS and FLOAT for floats_strength input
* Revert change, does not work
* Made patch_hook_weight_to_device respect set_func and convert_func
* Make discard_model_sampling True by default
* Add changes manually from 'master' so merge conflict resolution goes more smoothly
* Cleaned up text encode nodes with just a single clip.encode_from_tokens_scheduled call
* Make sure encode_from_tokens_scheduled will respect use_clip_schedule on clip
* Made nodes in nodes_hooks be marked as experimental (beta)
* Add get_nested_additional_models for cases where additional_models could have their own additional_models, and add robustness for circular additional_models references
* Made finalize_default_conds area math consistent with other sampling code
* Changed 'opt_hooks' input of Cond/Cond Pair Set Default Combine nodes to 'hooks'
* Remove a couple old TODOs and a no longer necessary workaround
---
 comfy/controlnet.py                          |  22 +-
 comfy/hooks.py                               | 690 ++++++++++++++
 .../modules/diffusionmodules/openaimodel.py  |  17 +
 comfy/lora.py                                |  16 +-
 comfy/model_base.py                          |  12 +
 comfy/model_patcher.py                       | 870 ++++++++++++++----
 comfy/patcher_extension.py                   | 156 ++++
 comfy/sampler_helpers.py                     |  72 +-
 comfy/samplers.py                            | 377 +++++---
 comfy/sd.py                                  |  73 ++
 comfy_extras/nodes_clip_sdxl.py              |   6 +-
 comfy_extras/nodes_flux.py                   |   5 +-
 comfy_extras/nodes_hooks.py                  | 697 ++++++++++++++
 comfy_extras/nodes_hunyuan.py                |   4 +-
 comfy_extras/nodes_sd3.py                    |   3 +-
 nodes.py                                     |   6 +-
 16 files changed, 2707 insertions(+), 319 deletions(-)
 create mode 100644 comfy/hooks.py
 create mode 100644 comfy/patcher_extension.py
 create mode 100644 comfy_extras/nodes_hooks.py

diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index a44f3725e80..e6a0d1e5976 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -36,6 +36,10 @@
 import comfy.ldm.hydit.controlnet
 import comfy.ldm.flux.controlnet
 import comfy.cldm.dit_embedder
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from comfy.hooks import HookGroup
+
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
@@ -78,6 +82,7 @@ def __init__(self):
         self.concat_mask = False
         self.extra_concat_orig = []
         self.extra_concat = None
+        self.extra_hooks: HookGroup = None
         self.preprocess_image = lambda a: a
 
     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
@@ -115,6 +120,14 @@ def get_models(self):
         if self.previous_controlnet is not None:
             out += self.previous_controlnet.get_models()
         return out
+
+    def get_extra_hooks(self):
+        out = []
+        if self.extra_hooks is not None:
+            out.append(self.extra_hooks)
+        if self.previous_controlnet is not None:
+            out += self.previous_controlnet.get_extra_hooks()
+        return out
 
     def copy_to(self, c):
         c.cond_hint_original = self.cond_hint_original
@@ -130,6 +143,7 @@ def copy_to(self, c):
         c.strength_type = self.strength_type
         c.concat_mask = self.concat_mask
         c.extra_concat_orig =
self.extra_concat_orig.copy() + c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None c.preprocess_image = self.preprocess_image def inference_memory_requirements(self, dtype): @@ -200,10 +214,10 @@ def __init__(self, control_model=None, global_average_pooling=False, compression self.concat_mask = concat_mask self.preprocess_image = preprocess_image - def get_control(self, x_noisy, t, cond, batched_number): + def get_control(self, x_noisy, t, cond, batched_number, transformer_options): control_prev = None if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options) if self.timestep_range is not None: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: @@ -758,10 +772,10 @@ def scale_image_to(self, width, height): height = math.ceil(height / unshuffle_amount) * unshuffle_amount return width, height - def get_control(self, x_noisy, t, cond, batched_number): + def get_control(self, x_noisy, t, cond, batched_number, transformer_options): control_prev = None if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options) if self.timestep_range is not None: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: diff --git a/comfy/hooks.py b/comfy/hooks.py new file mode 100644 index 00000000000..ccb8183b91d --- /dev/null +++ b/comfy/hooks.py @@ -0,0 +1,690 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Callable +import enum +import math +import torch +import numpy as np +import itertools + +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher, PatcherInjection + from comfy.model_base import BaseModel + from comfy.sd import CLIP +import comfy.lora +import comfy.model_management +import comfy.patcher_extension +from node_helpers import conditioning_set_values + +class EnumHookMode(enum.Enum): + MinVram = "minvram" + MaxSpeed = "maxspeed" + +class EnumHookType(enum.Enum): + Weight = "weight" + Patch = "patch" + ObjectPatch = "object_patch" + AddModels = "add_models" + Callbacks = "callbacks" + Wrappers = "wrappers" + SetInjections = "add_injections" + +class EnumWeightTarget(enum.Enum): + Model = "model" + Clip = "clip" + +class _HookRef: + pass + +# NOTE: this is an example of how the should_register function should look +def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + return True + + +class Hook: + def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, + hook_keyframe: 'HookKeyframeGroup'=None): + self.hook_type = hook_type + self.hook_ref = hook_ref if hook_ref else _HookRef() + self.hook_id = hook_id + self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + self.custom_should_register = default_should_register + self.auto_apply_to_nonpositive = False + + @property + def strength(self): + return self.hook_keyframe.strength + + def initialize_timesteps(self, model: 'BaseModel'): + self.reset() + self.hook_keyframe.initialize_timesteps(model) + + def reset(self): + self.hook_keyframe.reset() + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: Hook = subtype() + c.hook_type = 
self.hook_type + c.hook_ref = self.hook_ref + c.hook_id = self.hook_id + c.hook_keyframe = self.hook_keyframe + c.custom_should_register = self.custom_should_register + # TODO: make this do something + c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive + return c + + def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + return self.custom_should_register(self, model, model_options, target, registered) + + def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") + + def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]): + pass + + def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]): + pass + + def __eq__(self, other: 'Hook'): + return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref + + def __hash__(self): + return hash(self.hook_ref) + +class WeightHook(Hook): + def __init__(self, strength_model=1.0, strength_clip=1.0): + super().__init__(hook_type=EnumHookType.Weight) + self.weights: dict = None + self.weights_clip: dict = None + self.need_weight_init = True + self._strength_model = strength_model + self._strength_clip = strength_clip + + @property + def strength_model(self): + return self._strength_model * self.strength + + @property + def strength_clip(self): + return self._strength_clip * self.strength + + def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + if not self.should_register(model, model_options, target, registered): + return False + weights = None + if target == EnumWeightTarget.Model: + strength = self._strength_model + else: + strength = self._strength_clip + + if self.need_weight_init: + key_map = {} + if target == EnumWeightTarget.Model: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + else: + key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) + weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) + else: + if target == EnumWeightTarget.Model: + weights = self.weights + else: + weights = self.weights_clip + k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) + registered.append(self) + return True + # TODO: add logs about any keys that were not applied + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: WeightHook = super().clone(subtype) + c.weights = self.weights + c.weights_clip = self.weights_clip + c.need_weight_init = self.need_weight_init + c._strength_model = self._strength_model + c._strength_clip = self._strength_clip + return c + +class PatchHook(Hook): + def __init__(self): + super().__init__(hook_type=EnumHookType.Patch) + self.patches: dict = None + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: PatchHook = super().clone(subtype) + c.patches = self.patches + return c + # TODO: add functionality + +class ObjectPatchHook(Hook): + def __init__(self): + super().__init__(hook_type=EnumHookType.ObjectPatch) + self.object_patches: dict = None + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: ObjectPatchHook = super().clone(subtype) + c.object_patches = self.object_patches + return c + # TODO: add functionality + +class AddModelsHook(Hook): + def __init__(self, key: str=None, models: 
list['ModelPatcher']=None): + super().__init__(hook_type=EnumHookType.AddModels) + self.key = key + self.models = models + self.append_when_same = True + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: AddModelsHook = super().clone(subtype) + c.key = self.key + c.models = self.models.copy() if self.models else self.models + c.append_when_same = self.append_when_same + return c + # TODO: add functionality + +class CallbackHook(Hook): + def __init__(self, key: str=None, callback: Callable=None): + super().__init__(hook_type=EnumHookType.Callbacks) + self.key = key + self.callback = callback + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: CallbackHook = super().clone(subtype) + c.key = self.key + c.callback = self.callback + return c + # TODO: add functionality + +class WrapperHook(Hook): + def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): + super().__init__(hook_type=EnumHookType.Wrappers) + self.wrappers_dict = wrappers_dict + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: WrapperHook = super().clone(subtype) + c.wrappers_dict = self.wrappers_dict + return c + + def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + if not self.should_register(model, model_options, target, registered): + return False + add_model_options = {"transformer_options": self.wrappers_dict} + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + registered.append(self) + return True + +class SetInjectionsHook(Hook): + def __init__(self, key: str=None, injections: list['PatcherInjection']=None): + super().__init__(hook_type=EnumHookType.SetInjections) + self.key = key + self.injections = injections + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: SetInjectionsHook = super().clone(subtype) + c.key = self.key + c.injections = self.injections.copy() if self.injections else self.injections + return c + + def add_hook_injections(self, model: 'ModelPatcher'): + # TODO: add functionality + pass + +class HookGroup: + def __init__(self): + self.hooks: list[Hook] = [] + + def add(self, hook: Hook): + if hook not in self.hooks: + self.hooks.append(hook) + + def contains(self, hook: Hook): + return hook in self.hooks + + def clone(self): + c = HookGroup() + for hook in self.hooks: + c.add(hook.clone()) + return c + + def clone_and_combine(self, other: 'HookGroup'): + c = self.clone() + if other is not None: + for hook in other.hooks: + c.add(hook.clone()) + return c + + def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): + if hook_kf is None: + hook_kf = HookKeyframeGroup() + else: + hook_kf = hook_kf.clone() + for hook in self.hooks: + hook.hook_keyframe = hook_kf + + def get_dict_repr(self): + d: dict[EnumHookType, dict[Hook, None]] = {} + for hook in self.hooks: + with_type = d.setdefault(hook.hook_type, {}) + with_type[hook] = None + return d + + def get_hooks_for_clip_schedule(self): + scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} + for hook in self.hooks: + # only care about WeightHooks, for now + if hook.hook_type == EnumHookType.Weight: + hook_schedule = [] + # if no hook keyframes, assign default value + if len(hook.hook_keyframe.keyframes) == 0: + hook_schedule.append(((0.0, 1.0), None)) + scheduled_hooks[hook] = hook_schedule + continue + 
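+                # NOTE: the range-finding below collapses this hook's keyframes into
+                # (start_percent, next_start_percent) tuples, merging consecutive
+                # keyframes with (near-)identical strength so that each range maps to
+                # the single keyframe active throughout it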
+                # find ranges of values
+                prev_keyframe = hook.hook_keyframe.keyframes[0]
+                for keyframe in hook.hook_keyframe.keyframes:
+                    if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
+                        hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
+                        prev_keyframe = keyframe
+                    elif keyframe.start_percent == prev_keyframe.start_percent:
+                        prev_keyframe = keyframe
+                # create final range, assuming last start_percent was not 1.0
+                if not math.isclose(prev_keyframe.start_percent, 1.0):
+                    hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
+                scheduled_hooks[hook] = hook_schedule
+        # hooks now have their schedules as a list of (range, keyframe) tuples
+        all_ranges: list[tuple[float, float]] = []
+        for range_kfs in scheduled_hooks.values():
+            for t_range, keyframe in range_kfs:
+                all_ranges.append(t_range)
+        # turn list of ranges into boundaries
+        boundaries_set = set(itertools.chain.from_iterable(all_ranges))
+        boundaries_set.add(0.0)
+        boundaries = sorted(boundaries_set)
+        real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
+        # with real ranges defined, give appropriate hooks w/ keyframes for each range
+        scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
+        for t_range in real_ranges:
+            hooks_schedule = []
+            for hook, val in scheduled_hooks.items():
+                keyframe = None
+                # check if there is a keyframe that works for the current t_range
+                for stored_range, stored_kf in val:
+                    # if stored range overlaps the current range, it fits - give it the assigned keyframe
+                    if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
+                        keyframe = stored_kf
+                        break
+                hooks_schedule.append((hook, keyframe))
+            scheduled_keyframes.append((t_range, hooks_schedule))
+        return scheduled_keyframes
+
+    def reset(self):
+        for hook in self.hooks:
+            hook.reset()
+
+    @staticmethod
+    def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
+        actual: list[HookGroup] = []
+        for group in hooks_list:
+            if group is not None:
+                actual.append(group)
+        if len(actual) < require_count:
+            raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
+        # if no hooks, then return None
+        if len(actual) == 0:
+            return None
+        # if only 1 hook, just return itself without cloning
+        elif len(actual) == 1:
+            return actual[0]
+        final_hook: HookGroup = None
+        for hook in actual:
+            if final_hook is None:
+                final_hook = hook.clone()
+            else:
+                final_hook = final_hook.clone_and_combine(hook)
+        return final_hook
+
+
+class HookKeyframe:
+    def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
+        self.strength = strength
+        # scheduling
+        self.start_percent = float(start_percent)
+        self.start_t = 999999999.9
+        self.guarantee_steps = guarantee_steps
+
+    def clone(self):
+        c = HookKeyframe(strength=self.strength,
+                         start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
+        c.start_t = self.start_t
+        return c
+
+class HookKeyframeGroup:
+    def __init__(self):
+        self.keyframes: list[HookKeyframe] = []
+        self._current_keyframe: HookKeyframe = None
+        self._current_used_steps = 0
+        self._current_index = 0
+        self._current_strength = None
+        self._curr_t = -1.
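+        # NOTE: the _current_* fields above are the scheduling state consumed by
+        # prepare_current_keyframe: _current_index/_current_keyframe advance once a
+        # keyframe's start_t sigma is reached, _current_used_steps enforces
+        # guarantee_steps, and _curr_t caches the last sigma seen so repeated calls
+        # within the same step are no-ops.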
+
+    # properties shadow those of HookKeyframe
+    @property
+    def strength(self):
+        if self._current_keyframe is not None:
+            return self._current_keyframe.strength
+        return 1.0
+
+    def reset(self):
+        self._current_keyframe = None
+        self._current_used_steps = 0
+        self._current_index = 0
+        self._current_strength = None
+        self._curr_t = -1.
+        self._set_first_as_current()
+
+    def add(self, keyframe: HookKeyframe):
+        # add to end of list, then sort
+        self.keyframes.append(keyframe)
+        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
+        self._set_first_as_current()
+
+    def _set_first_as_current(self):
+        if len(self.keyframes) > 0:
+            self._current_keyframe = self.keyframes[0]
+        else:
+            self._current_keyframe = None
+
+    def has_index(self, index: int):
+        return index >= 0 and index < len(self.keyframes)
+
+    def is_empty(self):
+        return len(self.keyframes) == 0
+
+    def clone(self):
+        c = HookKeyframeGroup()
+        for keyframe in self.keyframes:
+            c.keyframes.append(keyframe.clone())
+        c._set_first_as_current()
+        return c
+
+    def initialize_timesteps(self, model: 'BaseModel'):
+        for keyframe in self.keyframes:
+            keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
+
+    def prepare_current_keyframe(self, curr_t: float) -> bool:
+        if self.is_empty():
+            return False
+        if curr_t == self._curr_t:
+            return False
+        prev_index = self._current_index
+        prev_strength = self._current_strength
+        # if met guaranteed steps, look for next keyframe in case need to switch
+        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
+            # if has next index, loop through and see if need to switch
+            if self.has_index(self._current_index+1):
+                for i in range(self._current_index+1, len(self.keyframes)):
+                    eval_c = self.keyframes[i]
+                    # check if start_t is greater or equal to curr_t
+                    # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
+                    if eval_c.start_t >= curr_t:
+                        self._current_index = i
+                        self._current_strength = eval_c.strength
+                        self._current_keyframe = eval_c
+                        self._current_used_steps = 0
+                        # if guarantee_steps greater than zero, stop searching for other keyframes
+                        if self._current_keyframe.guarantee_steps > 0:
+                            break
+                    # if eval_c is outside the percent range, stop looking further
+                    else: break
+        # update steps current context is used
+        self._current_used_steps += 1
+        # update current timestep this was performed on
+        self._curr_t = curr_t
+        # return True if keyframe changed, False if no change
+        return prev_index != self._current_index and prev_strength != self._current_strength
+
+
+class InterpolationMethod:
+    LINEAR = "linear"
+    EASE_IN = "ease_in"
+    EASE_OUT = "ease_out"
+    EASE_IN_OUT = "ease_in_out"
+
+    _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
+
+    @classmethod
+    def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
+        diff = num_to - num_from
+        if method == cls.LINEAR:
+            weights = torch.linspace(num_from, num_to, length)
+        elif method == cls.EASE_IN:
+            index = torch.linspace(0, 1, length)
+            weights = diff * np.power(index, 2) + num_from
+        elif method == cls.EASE_OUT:
+            index = torch.linspace(0, 1, length)
+            weights = diff * (1 - np.power(1 - index, 2)) + num_from
+        elif method == cls.EASE_IN_OUT:
+            index = torch.linspace(0, 1, length)
+            weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
+        else:
+            raise ValueError(f"Unrecognized interpolation method '{method}'.")
+        if reverse:
+            weights = weights.flip(dims=(0,))
+        return weights
+
+def
get_sorted_list_via_attr(objects: list, attr: str) -> list: + if not objects: + return objects + elif len(objects) <= 1: + return [x for x in objects] + # now that we know we have to sort, do it following these rules: + # a) if objects have same value of attribute, maintain their relative order + # b) perform sorting of the groups of objects with same attributes + unique_attrs = {} + for o in objects: + val_attr = getattr(o, attr) + attr_list: list = unique_attrs.get(val_attr, list()) + attr_list.append(o) + if val_attr not in unique_attrs: + unique_attrs[val_attr] = attr_list + # now that we have the unique attr values grouped together in relative order, sort them by key + sorted_attrs = dict(sorted(unique_attrs.items())) + # now flatten out the dict into a list to return + sorted_list = [] + for object_list in sorted_attrs.values(): + sorted_list.extend(object_list) + return sorted_list + +def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): + hook_group = HookGroup() + hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) + hook_group.add(hook) + hook.weights = lora + return hook_group + +def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float): + hook_group = HookGroup() + hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) + hook_group.add(hook) + patches_model = None + patches_clip = None + if weights_model is not None: + patches_model = {} + for key in weights_model: + patches_model[key] = ("model_as_lora", (weights_model[key],)) + if weights_clip is not None: + patches_clip = {} + for key in weights_clip: + patches_clip[key] = ("model_as_lora", (weights_clip[key],)) + hook.weights = patches_model + hook.weights_clip = patches_clip + hook.need_weight_init = False + return hook_group + +def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True): + if model is None: + return None + patches_model: dict[str, torch.Tensor] = model.model.state_dict() + if discard_model_sampling: + # do not include ANY model_sampling components of the model that should act as a patch + for key in list(patches_model.keys()): + if key.startswith("model_sampling"): + patches_model.pop(key, None) + return patches_model + +# NOTE: this function shows how to register weight hooks directly on the ModelPatchers +def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor], + strength_model: float, strength_clip: float): + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + + hook_group = HookGroup() + hook = WeightHook() + hook_group.add(hook) + loaded: dict[str] = comfy.lora.load_lora(lora, key_map) + if model is not None: + new_modelpatcher = model.clone() + k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model) + else: + k = () + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip) + else: + k1 = () + new_clip = None + k = set(k) + k1 = set(k1) + for x in loaded: + if (x not in k) and (x not in k1): + print(f"NOT LOADED {x}") + return (new_modelpatcher, new_clip, hook_group) + +def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, 
HookGroup], HookGroup]): + hooks_key = 'hooks' + # if hooks only exist in one dict, do what's needed so that it ends up in c_dict + if hooks_key not in values: + return + if hooks_key not in c_dict: + hooks_value = values.get(hooks_key, None) + if hooks_value is not None: + c_dict[hooks_key] = hooks_value + return + # otherwise, need to combine with minimum duplication via cache + hooks_tuple = (c_dict[hooks_key], values[hooks_key]) + cached_hooks = cache.get(hooks_tuple, None) + if cached_hooks is None: + new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1]) + cache[hooks_tuple] = new_hooks + c_dict[hooks_key] = new_hooks + else: + c_dict[hooks_key] = cache[hooks_tuple] + +def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): + c = [] + hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} + for t in conditioning: + n = [t[0], t[1].copy()] + for k in values: + if append_hooks and k == 'hooks': + _combine_hooks_from_values(n[1], values, hooks_combine_cache) + else: + n[1][k] = values[k] + c.append(n) + + return c + +def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): + if hooks is None: + return cond + return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) + +def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): + if timestep_range is None: + return cond + return conditioning_set_values(cond, {"start_percent": timestep_range[0], + "end_percent": timestep_range[1]}) + +def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float): + if mask is None: + return cond + set_area_to_bounds = False + if set_cond_area != 'default': + set_area_to_bounds = True + if len(mask.shape) < 3: + mask = mask.unsqueeze(0) + return conditioning_set_values(cond, {'mask': mask, + 'set_area_to_bounds': set_area_to_bounds, + 'mask_strength': strength}) + +def combine_conditioning(conds: list): + combined_conds = [] + for cond in conds: + combined_conds.extend(cond) + return combined_conds + +def combine_with_new_conds(conds: list, new_conds: list): + combined_conds = [] + for c, new_c in zip(conds, new_conds): + combined_conds.append(combine_conditioning([c, new_c])) + return combined_conds + +def set_conds_props(conds: list, strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): + final_conds = [] + for c in conds: + # first, apply lora_hook to conditioning, if provided + c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) + # next, apply mask to conditioning + c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) + # apply timesteps, if present + c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range) + # finally, apply mask to conditioning and store + final_conds.append(c) + return final_conds + +def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", + mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): + combined_conds = [] + for c, masked_c in zip(conds, new_conds): + # first, apply lora_hook to new conditioning, if provided + masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) + # next, apply mask to new conditioning, if provided + masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, 
strength=strength) + # apply timesteps, if present + masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range) + # finally, combine with existing conditioning and store + combined_conds.append(combine_conditioning([c, masked_c])) + return combined_conds + +def set_default_conds_and_combine(conds: list, new_conds: list, + hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): + combined_conds = [] + for c, new_c in zip(conds, new_conds): + # first, apply lora_hook to new conditioning, if provided + new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks) + # next, add default_cond key to cond so that during sampling, it can be identified + new_c = conditioning_set_values(new_c, {'default': True}) + # apply timesteps, if present + new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range) + # finally, combine with existing conditioning and store + combined_conds.append(combine_conditioning([c, new_c])) + return combined_conds diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 2902073d5ea..3f7fee708ff 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -15,6 +15,7 @@ ) from ..attention import SpatialTransformer, SpatialVideoTransformer, default from comfy.ldm.util import exists +import comfy.patcher_extension import comfy.ops ops = comfy.ops.disable_weight_init @@ -47,6 +48,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: + if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: + found_patched = False + for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: + if isinstance(layer, class_type): + x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) + found_patched = True + break + if found_patched: + continue x = layer(x) return x @@ -819,6 +829,13 @@ def get_resblock( ) def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timesteps, context, y, control, transformer_options, **kwargs) + + def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. 
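Note: the openaimodel.py hunk above routes UNetModel.forward through WrapperExecutor, so outside code can wrap the entire UNet call via entries in transformer_options. Below is a minimal sketch of what such a wrapper could look like, following the signature implied by the execute(...) call above (executor first, then the wrapped call's arguments, with the executor assumed callable from inside the wrapper); the add_wrapper registration helper is assumed from the commit notes and does not appear in this excerpt.

import comfy.patcher_extension
from comfy.patcher_extension import WrappersMP

def unet_forward_wrapper(executor, x, timesteps=None, context=None, y=None,
                         control=None, transformer_options={}, **kwargs):
    # code here runs before the wrapped forward; executor chains to the next
    # registered wrapper, or to the original _forward when this is the last one
    out = executor(x, timesteps, context, y, control, transformer_options, **kwargs)
    # code here runs after the wrapped forward; out is the UNet prediction
    return out

# hypothetical registration on a ModelPatcher instance:
# model_patcher.add_wrapper(WrappersMP.DIFFUSION_MODEL, unet_forward_wrapper)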
diff --git a/comfy/lora.py b/comfy/lora.py index 1080169b191..b6d9a8d04f8 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -33,7 +33,7 @@ } -def load_lora(lora, to_load): +def load_lora(lora, to_load, log_missing=True): patch_dict = {} loaded_keys = set() for x in to_load: @@ -213,9 +213,10 @@ def load_lora(lora, to_load): patch_dict[to_load[x]] = ("set", (set_weight,)) loaded_keys.add(set_weight_name) - for x in lora.keys(): - if x not in loaded_keys: - logging.warning("lora key not loaded: {}".format(x)) + if log_missing: + for x in lora.keys(): + if x not in loaded_keys: + logging.warning("lora key not loaded: {}".format(x)) return patch_dict @@ -429,7 +430,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten return padded_tensor -def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): +def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None): for p in patches: strength = p[0] v = p[1] @@ -471,6 +472,11 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype)) elif patch_type == "set": weight.copy_(v[0]) + elif patch_type == "model_as_lora": + target_weight: torch.Tensor = v[0] + diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \ + comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype) + weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype)) elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) diff --git a/comfy/model_base.py b/comfy/model_base.py index c305014a4a3..8f37af660d6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -33,12 +33,16 @@ import comfy.ldm.lightricks.model import comfy.model_management +import comfy.patcher_extension import comfy.conds import comfy.ops from enum import Enum from . 
import utils import comfy.latent_formats import math +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher class ModelType(Enum): EPS = 1 @@ -95,6 +99,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod self.model_config = model_config self.manual_cast_dtype = model_config.manual_cast_dtype self.device = device + self.current_patcher: 'ModelPatcher' = None if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: @@ -120,6 +125,13 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod self.memory_usage_factor = model_config.memory_usage_factor def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._apply_model, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options) + ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b64c795aa50..4ae3ad25db2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -16,6 +16,8 @@ along with this program. If not, see . """ +from __future__ import annotations +from typing import Optional, Callable import torch import copy import inspect @@ -28,6 +30,9 @@ import comfy.float import comfy.model_management import comfy.lora +import comfy.hooks +import comfy.patcher_extension +from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection from comfy.comfy_types import UnetWrapperFunction def string_to_seed(data): @@ -76,6 +81,17 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ model_options["disable_cfg1_optimization"] = True return model_options +def create_model_options_clone(orig_model_options: dict): + return comfy.patcher_extension.copy_nested_dicts(orig_model_options) + +def create_hook_patches_clone(orig_hook_patches): + new_hook_patches = {} + for hook_ref in orig_hook_patches: + new_hook_patches[hook_ref] = {} + for k in orig_hook_patches[hook_ref]: + new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] + return new_hook_patches + def wipe_lowvram_weight(m): if hasattr(m, "prev_comfy_cast_weights"): m.comfy_cast_weights = m.prev_comfy_cast_weights @@ -119,6 +135,49 @@ def get_key_weight(model, key): return weight, set_func, convert_func +class AutoPatcherEjector: + def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): + self.model = model + self.was_injected = False + self.prev_skip_injection = False + self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only + + def __enter__(self): + self.was_injected = False + self.prev_skip_injection = self.model.skip_injection + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = True + if self.model.is_injected: + self.model.eject_model() + self.was_injected = True + + def __exit__(self, *args): + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = self.prev_skip_injection + self.model.inject_model() + if self.was_injected and not self.model.skip_injection: + self.model.inject_model() + self.model.skip_injection = self.prev_skip_injection + 
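+# NOTE: AutoPatcherEjector is used below via ModelPatcher.use_ejected() (its
+# definition is not shown in this excerpt): `with patcher.use_ejected(): ...`
+# ejects active injections on enter and re-applies them on exit if they were
+# active and skip_injection is unset; with skip_and_inject_on_exit_only=True,
+# injection happens only on exit.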
+class MemoryCounter: + def __init__(self, initial: int, minimum=0): + self.value = initial + self.minimum = minimum + # TODO: add a safe limit besides 0 + + def use(self, weight: torch.Tensor): + weight_size = weight.nelement() * weight.element_size() + if self.is_useable(weight_size): + self.decrement(weight_size) + return True + return False + + def is_useable(self, used: int): + return self.value - used > self.minimum + + def decrement(self, used: int): + self.value -= used + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -141,6 +200,24 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.patches_uuid = uuid.uuid4() self.parent = None + self.attachments: dict[str] = {} + self.additional_models: dict[str, list[ModelPatcher]] = {} + self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks() + self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers() + + self.is_injected = False + self.skip_injection = False + self.injections: dict[str, list[PatcherInjection]] = {} + + self.hook_patches: dict[comfy.hooks._HookRef] = {} + self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} + self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} + self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} + self.current_hooks: Optional[comfy.hooks.HookGroup] = None + self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time + self.is_clip = False + self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -177,6 +254,47 @@ def clone(self): n.backup = self.backup n.object_patches_backup = self.object_patches_backup n.parent = self + + # attachments + n.attachments = {} + for k in self.attachments: + if hasattr(self.attachments[k], "on_model_patcher_clone"): + n.attachments[k] = self.attachments[k].on_model_patcher_clone() + else: + n.attachments[k] = self.attachments[k] + # additional models + for k, c in self.additional_models.items(): + n.additional_models[k] = [x.clone() for x in c] + # callbacks + for k, c in self.callbacks.items(): + n.callbacks[k] = {} + for k1, c1 in c.items(): + n.callbacks[k][k1] = c1.copy() + # sample wrappers + for k, w in self.wrappers.items(): + n.wrappers[k] = {} + for k1, w1 in w.items(): + n.wrappers[k][k1] = w1.copy() + # injection + n.is_injected = self.is_injected + n.skip_injection = self.skip_injection + for k, i in self.injections.items(): + n.injections[k] = i.copy() + # hooks + n.hook_patches = create_hook_patches_clone(self.hook_patches) + n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) + for group in self.cached_hook_patches: + n.cached_hook_patches[group] = {} + for k in self.cached_hook_patches[group]: + n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k] + n.hook_backup = self.hook_backup + n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks + n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks + n.is_clip = self.is_clip + n.hook_mode = self.hook_mode + + for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): + callback(self, n) return n def is_clone(self, other): @@ -184,10 +302,29 @@ def is_clone(self, other): return True return False - def clone_has_same_weights(self, clone): + def 
clone_has_same_weights(self, clone: 'ModelPatcher'): if not self.is_clone(clone): return False + if self.current_hooks != clone.current_hooks: + return False + if self.forced_hooks != clone.forced_hooks: + return False + if self.hook_patches.keys() != clone.hook_patches.keys(): + return False + if self.attachments.keys() != clone.attachments.keys(): + return False + if self.additional_models.keys() != clone.additional_models.keys(): + return False + for key in self.callbacks: + if len(self.callbacks[key]) != len(clone.callbacks[key]): + return False + for key in self.wrappers: + if len(self.wrappers[key]) != len(clone.wrappers[key]): + return False + if self.injections.keys() != clone.injections.keys(): + return False + if len(self.patches) == 0 and len(clone.patches) == 0: return True @@ -256,6 +393,12 @@ def set_model_input_block_patch_after_skip(self, patch): def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") + def set_model_emb_patch(self, patch): + self.set_model_patch(patch, "emb_patch") + + def set_model_forward_timestep_embed_patch(self, patch): + self.set_model_patch(patch, "forward_timestep_embed_patch") + def add_object_patch(self, name, obj): self.object_patches[name] = obj @@ -294,27 +437,28 @@ def model_dtype(self): return self.model.get_dtype() def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] + with self.use_ejected(): + p = set() + model_sd = self.model.state_dict() + for k in patches: + offset = None + function = None + if isinstance(k, str): + key = k + else: + offset = k[1] + key = k[0] + if len(k) > 2: + function = k[2] - if key in model_sd: - p.add(k) - current_patches = self.patches.get(key, []) - current_patches.append((strength_patch, patches[k], strength_model, offset, function)) - self.patches[key] = current_patches + if key in model_sd: + p.add(k) + current_patches = self.patches.get(key, []) + current_patches.append((strength_patch, patches[k], strength_model, offset, function)) + self.patches[key] = current_patches - self.patches_uuid = uuid.uuid4() - return list(p) + self.patches_uuid = uuid.uuid4() + return list(p) def get_key_patches(self, filter_prefix=None): model_sd = self.model_state_dict() @@ -324,9 +468,12 @@ def get_key_patches(self, filter_prefix=None): if not k.startswith(filter_prefix): continue bk = self.backup.get(k, None) + hbk = self.hook_backup.get(k, None) weight, set_func, convert_func = get_key_weight(self.model, k) if bk is not None: weight = bk.weight + if hbk is not None: + weight = hbk[0] if convert_func is None: convert_func = lambda a, **kwargs: a @@ -337,13 +484,14 @@ def get_key_patches(self, filter_prefix=None): return p def model_state_dict(self, filter_prefix=None): - sd = self.model.state_dict() - keys = list(sd.keys()) - if filter_prefix is not None: - for k in keys: - if not k.startswith(filter_prefix): - sd.pop(k) - return sd + with self.use_ejected(): + sd = self.model.state_dict() + keys = list(sd.keys()) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) + return sd def patch_weight_to_device(self, key, device_to=None, inplace_update=False): if key not in self.patches: @@ -388,106 +536,117 @@ def _load_list(self): return loading def load(self, device_to=None, lowvram_model_memory=0, 
force_patch_weights=False, full_load=False): - mem_counter = 0 - patch_counter = 0 - lowvram_counter = 0 - loading = self._load_list() - - load_completely = [] - loading.sort(reverse=True) - for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] - - lowvram_weight = False - - if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: - lowvram_weight = True - lowvram_counter += 1 - if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed - continue + with self.use_ejected(): + self.unpatch_hooks() + mem_counter = 0 + patch_counter = 0 + lowvram_counter = 0 + loading = self._load_list() + + load_completely = [] + loading.sort(reverse=True) + for x in loading: + n = x[1] + m = x[2] + params = x[3] + module_mem = x[0] + + lowvram_weight = False + + if not full_load and hasattr(m, "comfy_cast_weights"): + if mem_counter + module_mem >= lowvram_model_memory: + lowvram_weight = True + lowvram_counter += 1 + if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed + continue - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) - if lowvram_weight: - if weight_key in self.patches: - if force_patch_weights: - self.patch_weight_to_device(weight_key) - else: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - if force_patch_weights: - self.patch_weight_to_device(bias_key) - else: - m.bias_function = LowVramPatch(bias_key, self.patches) - patch_counter += 1 + if lowvram_weight: + if weight_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - else: - if hasattr(m, "comfy_cast_weights"): - if m.comfy_cast_weights: - wipe_lowvram_weight(m) - - if full_load or mem_counter + module_mem < lowvram_model_memory: - mem_counter += module_mem - load_completely.append((module_mem, n, m, params)) - - load_completely.sort(reverse=True) - for x in load_completely: - n = x[1] - m = x[2] - params = x[3] - if hasattr(m, "comfy_patched_weights"): - if m.comfy_patched_weights == True: - continue + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + else: + if hasattr(m, "comfy_cast_weights"): + if m.comfy_cast_weights: + wipe_lowvram_weight(m) + + if full_load or mem_counter + module_mem < lowvram_model_memory: + mem_counter += module_mem + load_completely.append((module_mem, n, m, params)) + + load_completely.sort(reverse=True) + for x in load_completely: + n = x[1] + m = x[2] + params = x[3] + if hasattr(m, "comfy_patched_weights"): + if m.comfy_patched_weights == True: + continue - for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + for param in params: + self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) - logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) - m.comfy_patched_weights = True + logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) + m.comfy_patched_weights = True - for x in load_completely: - x[2].to(device_to) + for x in load_completely: + x[2].to(device_to) - 
if lowvram_counter > 0: - logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) - self.model.model_lowvram = True - else: - logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) - self.model.model_lowvram = False - if full_load: - self.model.to(device_to) - mem_counter = self.model_size() + if lowvram_counter > 0: + logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + self.model.model_lowvram = True + else: + logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + self.model.model_lowvram = False + if full_load: + self.model.to(device_to) + mem_counter = self.model_size() - self.model.lowvram_patch_counter += patch_counter - self.model.device = device_to - self.model.model_loaded_weight_memory = mem_counter - self.model.current_weight_patches_uuid = self.patches_uuid + self.model.lowvram_patch_counter += patch_counter + self.model.device = device_to + self.model.model_loaded_weight_memory = mem_counter + self.model.current_weight_patches_uuid = self.patches_uuid - def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): - for k in self.object_patches: - old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) - if k not in self.object_patches_backup: - self.object_patches_backup[k] = old + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): + callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) - if lowvram_model_memory == 0: - full_load = True - else: - full_load = False + self.apply_hooks(self.forced_hooks, force_apply=True) + + def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): + with self.use_ejected(): + for k in self.object_patches: + old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) + if k not in self.object_patches_backup: + self.object_patches_backup[k] = old + + if lowvram_model_memory == 0: + full_load = True + else: + full_load = False - if load_weights: - self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) + if load_weights: + self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) + self.inject_model() return self.model def unpatch_model(self, device_to=None, unpatch_weights=True): + self.eject_model() if unpatch_weights: + self.unpatch_hooks() if self.model.model_lowvram: for m in self.model.modules(): wipe_lowvram_weight(m) @@ -523,85 +682,91 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.object_patches_backup.clear() def partially_unload(self, device_to, memory_to_free=0): - memory_freed = 0 - patch_counter = 0 - unload_list = self._load_list() - unload_list.sort() - for unload in unload_list: - if memory_to_free < memory_freed: - break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] - - lowvram_possible = hasattr(m, "comfy_cast_weights") - if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: - move_weight = True - for param in params: - key = "{}.{}".format(n, param) - bk = self.backup.get(key, None) - if bk is not None: - if not lowvram_possible: - move_weight = False - break - - if 
bk.inplace_update: - comfy.utils.copy_to_param(self.model, key, bk.weight) - else: - comfy.utils.set_attr_param(self.model, key, bk.weight) - self.backup.pop(key) - - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if move_weight: - m.to(device_to) - if lowvram_possible: - if weight_key in self.patches: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - m.bias_function = LowVramPatch(bias_key, self.patches) - patch_counter += 1 - - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - m.comfy_patched_weights = False - memory_freed += module_mem - logging.debug("freed {}".format(n)) + with self.use_ejected(): + memory_freed = 0 + patch_counter = 0 + unload_list = self._load_list() + unload_list.sort() + for unload in unload_list: + if memory_to_free < memory_freed: + break + module_mem = unload[0] + n = unload[1] + m = unload[2] + params = unload[3] + + lowvram_possible = hasattr(m, "comfy_cast_weights") + if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: + move_weight = True + for param in params: + key = "{}.{}".format(n, param) + bk = self.backup.get(key, None) + if bk is not None: + if not lowvram_possible: + move_weight = False + break + + if bk.inplace_update: + comfy.utils.copy_to_param(self.model, key, bk.weight) + else: + comfy.utils.set_attr_param(self.model, key, bk.weight) + self.backup.pop(key) + + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if move_weight: + m.to(device_to) + if lowvram_possible: + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 + + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + m.comfy_patched_weights = False + memory_freed += module_mem + logging.debug("freed {}".format(n)) - self.model.model_lowvram = True - self.model.lowvram_patch_counter += patch_counter - self.model.model_loaded_weight_memory -= memory_freed - return memory_freed + self.model.model_lowvram = True + self.model.lowvram_patch_counter += patch_counter + self.model.model_loaded_weight_memory -= memory_freed + return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): - unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) - # TODO: force_patch_weights should not unload + reload full model - used = self.model.model_loaded_weight_memory - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) - if unpatch_weights: - extra_memory += (used - self.model.model_loaded_weight_memory) - - self.patch_model(load_weights=False) - full_load = False - if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: - return 0 - if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): - full_load = True - current_used = self.model.model_loaded_weight_memory - try: - self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) - except Exception as e: - self.detach() - raise e - - return self.model.model_loaded_weight_memory - current_used + with self.use_ejected(skip_and_inject_on_exit_only=True): + unpatch_weights = self.model.current_weight_patches_uuid is not None and 
(self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) + # TODO: force_patch_weights should not unload + reload full model + used = self.model.model_loaded_weight_memory + self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + if unpatch_weights: + extra_memory += (used - self.model.model_loaded_weight_memory) + + self.patch_model(load_weights=False) + full_load = False + if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: + self.apply_hooks(self.forced_hooks, force_apply=True) + return 0 + if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): + full_load = True + current_used = self.model.model_loaded_weight_memory + try: + self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) + except Exception as e: + self.detach() + raise e + + return self.model.model_loaded_weight_memory - current_used def detach(self, unpatch_all=True): + self.eject_model() self.model_patches_to(self.offload_device) if unpatch_all: self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) + for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH): + callback(self, unpatch_all) return self.model def current_loaded_device(self): @@ -611,6 +776,345 @@ def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float3 print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead") return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) + def cleanup(self): + self.clean_hooks() + if hasattr(self.model, "current_patcher"): + self.model.current_patcher = None + for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP): + callback(self) + + def add_callback(self, call_type: str, callback: Callable): + self.add_callback_with_key(call_type, None, callback) + + def add_callback_with_key(self, call_type: str, key: str, callback: Callable): + c = self.callbacks.setdefault(call_type, {}).setdefault(key, []) + c.append(callback) + + def remove_callbacks_with_key(self, call_type: str, key: str): + c = self.callbacks.get(call_type, {}) + if key in c: + c.pop(key) + + def get_callbacks(self, call_type: str, key: str): + return self.callbacks.get(call_type, {}).get(key, []) + + def get_all_callbacks(self, call_type: str): + c_list = [] + for c in self.callbacks.get(call_type, {}).values(): + c_list.extend(c) + return c_list + + def add_wrapper(self, wrapper_type: str, wrapper: Callable): + self.add_wrapper_with_key(wrapper_type, None, wrapper) + + def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): + w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, []) + w.append(wrapper) + + def remove_wrappers_with_key(self, wrapper_type: str, key: str): + w = self.wrappers.get(wrapper_type, {}) + if key in w: + w.pop(key) + + def get_wrappers(self, wrapper_type: str, key: str): + return self.wrappers.get(wrapper_type, {}).get(key, []) + + def get_all_wrappers(self, wrapper_type: str): + w_list = [] + for w in self.wrappers.get(wrapper_type, {}).values(): + w_list.extend(w) + return w_list + + def set_attachments(self, key: str, attachment): + self.attachments[key] = attachment + + def remove_attachments(self, key: str): + if key in self.attachments: + self.attachments.pop(key) + + def get_attachment(self, key: str): + return self.attachments.get(key, None) + + def set_injections(self, key: str, 
injections: list[PatcherInjection]): + self.injections[key] = injections + + def remove_injections(self, key: str): + if key in self.injections: + self.injections.pop(key) + + def set_additional_models(self, key: str, models: list['ModelPatcher']): + self.additional_models[key] = models + + def remove_additional_models(self, key: str): + if key in self.additional_models: + self.additional_models.pop(key) + + def get_additional_models_with_key(self, key: str): + return self.additional_models.get(key, []) + + def get_additional_models(self): + all_models = [] + for models in self.additional_models.values(): + all_models.extend(models) + return all_models + + def get_nested_additional_models(self): + def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]): + '''Make sure circular references do not cause infinite recursion.''' + next_models = [] + for model in prev_models: + candidates = model.get_additional_models() + for c in candidates: + if c not in cache_set: + next_models.append(c) + cache_set.add(c) + if len(next_models) == 0: + return prev_models + return prev_models + _evaluate_sub_additional_models(next_models, cache_set) + + all_models = self.get_additional_models() + models_set = set(all_models) + real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set) + return real_all_models + + def use_ejected(self, skip_and_inject_on_exit_only=False): + return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only) + + def inject_model(self): + if self.is_injected or self.skip_injection: + return + for injections in self.injections.values(): + for inj in injections: + inj.inject(self) + self.is_injected = True + if self.is_injected: + for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL): + callback(self) + + def eject_model(self): + if not self.is_injected: + return + for injections in self.injections.values(): + for inj in injections: + inj.eject(self) + self.is_injected = False + for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL): + callback(self) + + def pre_run(self): + if hasattr(self.model, "current_patcher"): + self.model.current_patcher = self + for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): + callback(self) + + def prepare_state(self, timestep): + for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): + callback(self, timestep) + + def restore_hook_patches(self): + if len(self.hook_patches_backup) > 0: + self.hook_patches = self.hook_patches_backup + self.hook_patches_backup = {} + + def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): + self.hook_mode = hook_mode + + def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup): + curr_t = t[0] + reset_current_hooks = False + for hook in hook_group.hooks: + changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t) + # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; + # this will cause the weights to be recalculated when sampling + if changed: + # reset current_hooks if contains hook that changed + if self.current_hooks is not None: + for current_hook in self.current_hooks.hooks: + if current_hook == hook: + reset_current_hooks = True + break + for cached_group in list(self.cached_hook_patches.keys()): + if cached_group.contains(hook): + self.cached_hook_patches.pop(cached_group) + if reset_current_hooks: + self.patch_hooks(None) + + def 
register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None): + self.restore_hook_patches() + registered_hooks: list[comfy.hooks.Hook] = [] + # handle WrapperHooks, if model_options provided + if model_options is not None: + for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): + hook.add_hook_patches(self, model_options, target, registered_hooks) + # handle WeightHooks + weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] + for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): + if hook.hook_ref not in self.hook_patches: + weight_hooks_to_register.append(hook) + if len(weight_hooks_to_register) > 0: + # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state + self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) + for hook in weight_hooks_to_register: + hook.add_hook_patches(self, model_options, target, registered_hooks) + for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): + callback(self, hooks_dict, target) + + def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): + with self.use_ejected(): + # NOTE: this mirrors behavior of add_patches func + current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {}) + p = set() + model_sd = self.model.state_dict() + for k in patches: + offset = None + function = None + if isinstance(k, str): + key = k + else: + offset = k[1] + key = k[0] + if len(k) > 2: + function = k[2] + + if key in model_sd: + p.add(k) + current_patches: list[tuple] = current_hook_patches.get(key, []) + current_patches.append((strength_patch, patches[k], strength_model, offset, function)) + current_hook_patches[key] = current_patches + self.hook_patches[hook.hook_ref] = current_hook_patches + # since should care about these patches too to determine if same model, reroll patches_uuid + self.patches_uuid = uuid.uuid4() + return list(p) + + def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup): + # combined_patches will contain weights of all relevant hooks, per key + combined_patches = {} + if hooks is not None: + for hook in hooks.hooks: + hook_patches: dict = self.hook_patches.get(hook.hook_ref, {}) + for key in hook_patches.keys(): + current_patches: list[tuple] = combined_patches.get(key, []) + if math.isclose(hook.strength, 1.0): + current_patches.extend(hook_patches[key]) + else: + # patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model) + for patch in hook_patches[key]: + new_patch = list(patch) + new_patch[0] *= hook.strength + current_patches.append(tuple(new_patch)) + combined_patches[key] = current_patches + return combined_patches + + def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): + # TODO: return transformer_options dict with any additions from hooks + if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): + return {} + self.patch_hooks(hooks=hooks) + for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): + callback(self, hooks) + return {} + + def patch_hooks(self, hooks: comfy.hooks.HookGroup): + with self.use_ejected(): + self.unpatch_hooks() + if hooks is not None: + model_sd_keys = list(self.model_state_dict().keys()) + memory_counter = None + if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: + # TODO: 
memory_counter should have a minimum that conforms to loaded model requirements
+                    memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
+                                                   minimum=comfy.model_management.minimum_inference_memory()*2)
+                # if cached weights exist for these hooks, use them
+                cached_weights = self.cached_hook_patches.get(hooks, None)
+                if cached_weights is not None:
+                    for key in cached_weights:
+                        if key not in model_sd_keys:
+                            logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
+                            continue
+                        self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
+                else:
+                    relevant_patches = self.get_combined_hook_patches(hooks=hooks)
+                    original_weights = None
+                    if len(relevant_patches) > 0:
+                        original_weights = self.get_key_patches()
+                    for key in relevant_patches:
+                        if key not in model_sd_keys:
+                            logging.warning(f"Hook would not patch. Key does not exist in model: {key}")
+                            continue
+                        self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
+                                                         memory_counter=memory_counter)
+            self.current_hooks = hooks
+
+    def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
+        if key not in self.hook_backup:
+            weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
+            target_device = self.offload_device
+            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
+                used = memory_counter.use(weight)
+                if used:
+                    target_device = weight.device
+            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
+        comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
+
+    def clear_cached_hook_weights(self):
+        self.cached_hook_patches.clear()
+        self.patch_hooks(None)
+
+    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
+        if key not in combined_patches:
+            return
+
+        weight, set_func, convert_func = get_key_weight(self.model, key)
+        weight: torch.Tensor
+        if key not in self.hook_backup:
+            target_device = self.offload_device
+            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
+                used = memory_counter.use(weight)
+                if used:
+                    target_device = weight.device
+            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
+        # TODO: properly handle LowVramPatch, if it ends up being an issue
+        temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
+        if convert_func is not None:
+            temp_weight = convert_func(temp_weight, inplace=True)
+
+        out_weight = comfy.lora.calculate_weight(combined_patches[key],
+                                                 temp_weight,
+                                                 key, original_weights=original_weights)
+        del original_weights[key]
+        if set_func is None:
+            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+            comfy.utils.copy_to_param(self.model, key, out_weight)
+        else:
+            set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
+        if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
+            # TODO: disable caching if not enough system RAM to do so
+            target_device = self.offload_device
+            used = memory_counter.use(weight)
+            if used:
+                target_device = weight.device
+            self.cached_hook_patches.setdefault(hooks, {})
+            self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
+        del temp_weight
+        del out_weight
+        del weight
+
+    def unpatch_hooks(self) -> 
None: + with self.use_ejected(): + if len(self.hook_backup) == 0: + self.current_hooks = None + return + keys = list(self.hook_backup.keys()) + for k in keys: + comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) + + self.hook_backup.clear() + self.current_hooks = None + + def clean_hooks(self): + self.unpatch_hooks() + self.clear_cached_hook_weights() + def __del__(self): self.detach(unpatch_all=False) diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py new file mode 100644 index 00000000000..5144691857b --- /dev/null +++ b/comfy/patcher_extension.py @@ -0,0 +1,156 @@ +from __future__ import annotations +from typing import Callable + +class CallbacksMP: + ON_CLONE = "on_clone" + ON_LOAD = "on_load_after" + ON_DETACH = "on_detach_after" + ON_CLEANUP = "on_cleanup" + ON_PRE_RUN = "on_pre_run" + ON_PREPARE_STATE = "on_prepare_state" + ON_APPLY_HOOKS = "on_apply_hooks" + ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches" + ON_INJECT_MODEL = "on_inject_model" + ON_EJECT_MODEL = "on_eject_model" + + # callbacks dict is in the format: + # {"call_type": {"key": [Callable1, Callable2, ...]} } + @classmethod + def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]: + return {} + +def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False): + add_callback_with_key(call_type, None, callback, transformer_options, is_model_options) + +def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.setdefault("transformer_options", {}) + callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {}) + c = callbacks.setdefault(call_type, {}).setdefault(key, []) + c.append(callback) + +def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.get("transformer_options", {}) + c_list = [] + callbacks: dict[str, list] = transformer_options.get("callbacks", {}) + c_list.extend(callbacks.get(call_type, {}).get(key, [])) + return c_list + +def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.get("transformer_options", {}) + c_list = [] + callbacks: dict[str, list] = transformer_options.get("callbacks", {}) + for c in callbacks.get(call_type, {}).values(): + c_list.extend(c) + return c_list + +class WrappersMP: + OUTER_SAMPLE = "outer_sample" + SAMPLER_SAMPLE = "sampler_sample" + CALC_COND_BATCH = "calc_cond_batch" + APPLY_MODEL = "apply_model" + DIFFUSION_MODEL = "diffusion_model" + + # wrappers dict is in the format: + # {"wrapper_type": {"key": [Callable1, Callable2, ...]} } + @classmethod + def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]: + return {} + +def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False): + add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options) + +def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.setdefault("transformer_options", {}) + wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {}) + w = wrappers.setdefault(wrapper_type, 
{}).setdefault(key, []) + w.append(wrapper) + +def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.get("transformer_options", {}) + w_list = [] + wrappers: dict[str, list] = transformer_options.get("wrappers", {}) + w_list.extend(wrappers.get(wrapper_type, {}).get(key, [])) + return w_list + +def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False): + if is_model_options: + transformer_options = transformer_options.get("transformer_options", {}) + w_list = [] + wrappers: dict[str, list] = transformer_options.get("wrappers", {}) + for w in wrappers.get(wrapper_type, {}).values(): + w_list.extend(w) + return w_list + +class WrapperExecutor: + """Handles call stack of wrappers around a function in an ordered manner.""" + def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int): + # NOTE: class_obj exists so that wrappers surrounding a class method can access + # the class instance at runtime via executor.class_obj + self.original = original + self.class_obj = class_obj + self.wrappers = wrappers.copy() + self.idx = idx + self.is_last = idx == len(wrappers) + + def __call__(self, *args, **kwargs): + """Calls the next wrapper or original function, whichever is appropriate.""" + new_executor = self._create_next_executor() + return new_executor.execute(*args, **kwargs) + + def execute(self, *args, **kwargs): + """Used to initiate executor internally - DO NOT use this if you received executor in wrapper.""" + args = list(args) + kwargs = dict(kwargs) + if self.is_last: + return self.original(*args, **kwargs) + return self.wrappers[self.idx](self, *args, **kwargs) + + def _create_next_executor(self) -> 'WrapperExecutor': + new_idx = self.idx + 1 + if new_idx > len(self.wrappers): + raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.") + if self.class_obj is None: + return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx) + return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx) + + @classmethod + def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0): + return cls(original, class_obj=None, wrappers=wrappers, idx=idx) + + @classmethod + def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0): + return cls(original, class_obj, wrappers, idx=idx) + +class PatcherInjection: + def __init__(self, inject: Callable, eject: Callable): + self.inject = inject + self.eject = eject + +def copy_nested_dicts(input_dict: dict): + new_dict = input_dict.copy() + for key, value in input_dict.items(): + if isinstance(value, dict): + new_dict[key] = copy_nested_dicts(value) + elif isinstance(value, list): + new_dict[key] = value.copy() + return new_dict + +def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True): + if copy_dict1: + merged_dict = copy_nested_dicts(dict1) + else: + merged_dict = dict1 + for key, value in dict2.items(): + if isinstance(value, dict): + curr_value = merged_dict.setdefault(key, {}) + merged_dict[key] = merge_nested_dicts(value, curr_value) + elif isinstance(value, list): + merged_dict.setdefault(key, []).extend(value) + else: + merged_dict[key] = value + return merged_dict diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1879e670a53..1252d8a5bf6 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ 
-1,7 +1,16 @@ +from __future__ import annotations +import uuid import torch import comfy.model_management import comfy.conds import comfy.utils +import comfy.hooks +import comfy.patcher_extension +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.model_base import BaseModel + from comfy.controlnet import ControlBase def prepare_mask(noise_mask, shape, device): return comfy.utils.reshape_mask(noise_mask, shape).to(device) @@ -10,9 +19,43 @@ def get_models_from_cond(cond, model_type): models = [] for c in cond: if model_type in c: - models += [c[model_type]] + if isinstance(c[model_type], list): + models += c[model_type] + else: + models += [c[model_type]] return models +def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): + # get hooks from conds, and collect cnets so they can be checked for extra_hooks + cnets: list[ControlBase] = [] + for c in cond: + if 'hooks' in c: + for hook in c['hooks'].hooks: + hook: comfy.hooks.Hook + with_type = hooks_dict.setdefault(hook.hook_type, {}) + with_type[hook] = None + if 'control' in c: + cnets.append(c['control']) + + def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list): + if cnet.extra_hooks is not None: + _list.append(cnet.extra_hooks) + if cnet.previous_controlnet is None: + return _list + return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list) + + hooks_list = [] + cnets = set(cnets) + for base_cnet in cnets: + get_extra_hooks_from_cnet(base_cnet, hooks_list) + extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) + if extra_hooks is not None: + for hook in extra_hooks.hooks: + with_type = hooks_dict.setdefault(hook.hook_type, {}) + with_type[hook] = None + + return hooks_dict + def convert_cond(cond): out = [] for c in cond: @@ -22,17 +65,22 @@ def convert_cond(cond): model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove temp["cross_attn"] = c[0] temp["model_conds"] = model_conds + temp["uuid"] = uuid.uuid4() out.append(temp) return out def get_additional_models(conds, dtype): """loads additional models in conditioning""" - cnets = [] + cnets: list[ControlBase] = [] gligen = [] + add_models = [] + hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {} for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") + add_models += get_models_from_cond(conds[k], "additional_models") + get_hooks_from_cond(conds[k], hooks) control_nets = set(cnets) @@ -43,7 +91,9 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - models = control_models + gligen + hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] + models = control_models + gligen + add_models + hook_models + return models, inference_memory def cleanup_additional_models(models): @@ -53,10 +103,11 @@ def cleanup_additional_models(models): m.cleanup() -def prepare_sampling(model, noise_shape, conds): +def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): device = model.load_device - real_model = None + real_model: 'BaseModel' = None models, inference_memory = get_additional_models(conds, model.model_dtype()) + models += model.get_nested_additional_models() # TODO: does this require inference_memory update? 
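+    # NOTE: nested additional models are gathered recursively (with a visited set, so
+    # circular ModelPatcher references terminate) and are loaded on the GPU below
+    # together with the main model; their inference memory is not yet counted here.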
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) @@ -72,3 +123,14 @@ def cleanup_models(conds, models): control_cleanup += get_models_from_cond(conds[k], "control") cleanup_additional_models(set(control_cleanup)) + +def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): + # check for hooks in conds - if not registered, see if can be applied + hooks = {} + for k in conds: + get_hooks_from_cond(conds[k], hooks) + # add wrappers and callbacks from ModelPatcher to transformer_options + model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) + model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) + # register hooks on model/model_options + model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options) diff --git a/comfy/samplers.py b/comfy/samplers.py index 94cba03b88f..b4c42160d3c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,11 +1,21 @@ +from __future__ import annotations from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.model_base import BaseModel + from comfy.controlnet import ControlBase import torch import collections from comfy import model_management import math import logging +import comfy.samplers import comfy.sampler_helpers +import comfy.model_patcher +import comfy.patcher_extension +import comfy.hooks import scipy.stats import numpy @@ -70,6 +80,7 @@ def get_area_and_mult(conds, x_in, timestep_in): for c in model_conds: conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + hooks = conds.get('hooks', None) control = conds.get('control', None) patches = None @@ -85,8 +96,8 @@ def get_area_and_mult(conds, x_in, timestep_in): patches['middle_patch'] = [gligen_patch] - cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) - return cond_obj(input_x, mult, conditioning, area, control, patches) + cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks']) + return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks) def cond_equal_size(c1, c2): if c1 is c2: @@ -138,110 +149,184 @@ def cond_cat(c_list): return out -def calc_cond_batch(model, conds, x_in, timestep, model_options): +def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep): + # need to figure out remaining unmasked area for conds + default_mults = [] + for _ in default_conds: + default_mults.append(torch.ones_like(x_in)) + # look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond + for lora_hooks, to_run in hooked_to_run.items(): + for cond_obj, i in to_run: + # if no default_cond for cond_type, do nothing + if len(default_conds[i]) == 0: + continue + area: list[int] = cond_obj.area + if area is not None: + curr_default_mult: torch.Tensor = default_mults[i] 
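+                # area stores sizes in its first half and offsets in its second half
+                # (matching how get_area_and_mult narrows tensors), so narrow the
+                # default mult down to the region this cond covers before subtracting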
+                dims = len(area) // 2
+                for d in range(dims):
+                    curr_default_mult = curr_default_mult.narrow(d + 2, area[d + dims], area[d])
+                curr_default_mult -= cond_obj.mult
+            else:
+                default_mults[i] -= cond_obj.mult
+    # for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
+    for i, mult in enumerate(default_mults):
+        # if no default_cond for cond type, do nothing
+        if len(default_conds[i]) == 0:
+            continue
+        torch.nn.functional.relu(mult, inplace=True)
+        # if mult is all zeros, then don't add default_cond
+        if torch.max(mult) == 0.0:
+            continue
+
+        cond = default_conds[i]
+        for x in cond:
+            # call get_area_and_mult to get all the expected values
+            p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
+            if p is None:
+                continue
+            # replace p's mult with the calculated leftover mult
+            p = p._replace(mult=mult)
+            if p.hooks is not None:
+                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
+            hooked_to_run.setdefault(p.hooks, list())
+            hooked_to_run[p.hooks] += [(p, i)]
+
+def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
+    executor = comfy.patcher_extension.WrapperExecutor.new_executor(
+        _calc_cond_batch,
+        comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
+    )
+    return executor.execute(model, conds, x_in, timestep, model_options)
+
+def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
     out_conds = []
     out_counts = []
-    to_run = []
+    # separate conds by matching hooks
+    hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
+    default_conds = []
+    has_default_conds = False

     for i in range(len(conds)):
         out_conds.append(torch.zeros_like(x_in))
         out_counts.append(torch.ones_like(x_in) * 1e-37)

         cond = conds[i]
+        default_c = []
         if cond is not None:
             for x in cond:
-                p = get_area_and_mult(x, x_in, timestep)
+                if 'default' in x:
+                    default_c.append(x)
+                    has_default_conds = True
+                    continue
+                p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
                 if p is None:
                     continue
+                if p.hooks is not None:
+                    model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
+                hooked_to_run.setdefault(p.hooks, list())
+                hooked_to_run[p.hooks] += [(p, i)]
+        default_conds.append(default_c)
+
+    if has_default_conds:
+        finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
+
+    model.current_patcher.prepare_state(timestep)
+
+    # run every hooked_to_run separately
+    for hooks, to_run in hooked_to_run.items():
+        while len(to_run) > 0:
+            first = to_run[0]
+            first_shape = first[0][0].shape
+            to_batch_temp = []
+            for x in range(len(to_run)):
+                if can_concat_cond(to_run[x][0], first[0]):
+                    to_batch_temp += [x]
+
+            to_batch_temp.reverse()
+            to_batch = to_batch_temp[:1]
+
+            free_memory = model_management.get_free_memory(x_in.device)
+            for i in range(1, len(to_batch_temp) + 1):
+                batch_amount = to_batch_temp[:len(to_batch_temp)//i]
+                input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
+                if model.memory_required(input_shape) * 1.5 < free_memory:
+                    to_batch = batch_amount
+                    break
+
+            input_x = []
+            mult = []
+            c = []
+            cond_or_uncond = []
+            uuids = []
+            area = []
+            control = None
+            patches = None
+            for x in to_batch:
+                o = to_run.pop(x)
+                p = o[0]
+                input_x.append(p.input_x)
+                mult.append(p.mult)
+                c.append(p.conditioning)
+                area.append(p.area)
+                cond_or_uncond.append(o[1])
+                uuids.append(p.uuid)
+                control = p.control
+                patches = 
p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) + + transformer_options = model.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) + + if patches is not None: + # TODO: replace with merge_nested_dicts function + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches + else: + transformer_options["patches"] = patches - to_run += [(p, i)] - - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - to_batch_temp = [] - for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] - - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] - - free_memory = model_management.get_free_memory(x_in.device) - for i in range(1, len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp)//i] - input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) * 1.5 < free_memory: - to_batch = batch_amount - break - - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x.append(p.input_x) - mult.append(p.mult) - c.append(p.conditioning) - area.append(p.area) - cond_or_uncond.append(o[1]) - control = p.control - patches = p.patches - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - - transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() - - if patches is not None: - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - transformer_options["patches"] = cur_patches - else: - transformer_options["patches"] = patches + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["sigmas"] = timestep + c['transformer_options'] = transformer_options - c['transformer_options'] = transformer_options + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) - else: - output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - - for o in range(batch_chunks): - cond_index = cond_or_uncond[o] - a = area[o] - if a is None: - out_conds[cond_index] += output[o] * mult[o] - out_counts[cond_index] += mult[o] + if 'model_function_wrapper' in model_options: + output = 
model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: - out_c = out_conds[cond_index] - out_cts = out_counts[cond_index] - dims = len(a) // 2 - for i in range(dims): - out_c = out_c.narrow(i + 2, a[i + dims], a[i]) - out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) - out_c += output[o] * mult[o] - out_cts += mult[o] + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] for i in range(len(out_conds)): out_conds[i] /= out_counts[i] @@ -500,10 +585,15 @@ def calculate_start_end_timesteps(model, conds): timestep_start = None timestep_end = None - if 'start_percent' in x: - timestep_start = s.percent_to_sigma(x['start_percent']) - if 'end_percent' in x: - timestep_end = s.percent_to_sigma(x['end_percent']) + # handle clip hook schedule, if needed + if 'clip_start_percent' in x: + timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0))) + timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0))) + else: + if 'start_percent' in x: + timestep_start = s.percent_to_sigma(x['start_percent']) + if 'end_percent' in x: + timestep_end = s.percent_to_sigma(x['end_percent']) if (timestep_start is not None) or (timestep_end is not None): n = x.copy() @@ -673,6 +763,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N if k != kk: create_cond_with_same_area_if_none(conds[kk], c) + for k in conds: + for c in conds[k]: + if 'hooks' in c: + for hook in c['hooks'].hooks: + hook.initialize_timesteps(model) + for k in conds: pre_run_control(model, conds[k]) @@ -685,9 +781,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N return conds + +def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]): + # determine which ControlNets have extra_hooks that should be combined with normal hooks + hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {} + for k in conds: + for kk in conds[k]: + if 'control' in kk: + control: 'ControlBase' = kk['control'] + extra_hooks = control.get_extra_hooks() + if len(extra_hooks) > 0: + hooks: comfy.hooks.HookGroup = kk.get('hooks', None) + to_replace = hook_replacement.setdefault((control, hooks), []) + to_replace.append(kk) + # if nothing to replace, do nothing + if len(hook_replacement) == 0: + return + + # for optimal sampling performance, common ControlNets + hook combos should have identical hooks + # on the cond dicts + for key, conds_to_modify in hook_replacement.items(): + control = key[0] + hooks = key[1] + hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks]) + # if combined hooks are not None, set as new hooks for all relevant conds + if hooks is not None: + for cond in conds_to_modify: + cond['hooks'] = hooks + + +def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): + hooks_set = set() + for k in conds: + for kk in conds[k]: + hooks_set.add(kk.get('hooks', None)) + return 
len(hooks_set) + + class CFGGuider: def __init__(self, model_patcher): - self.model_patcher = model_patcher + self.model_patcher: 'ModelPatcher' = model_patcher self.model_options = model_patcher.model_options self.original_conds = {} self.cfg = 1.0 @@ -714,19 +847,17 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) - extra_args = {"model_options": self.model_options, "seed":seed} + extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed} - samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) + executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( + sampler.sample, + sampler, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True) + ) + samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) - def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - if sigmas.shape[-1] == 0: - return latent_image - - self.conds = {} - for k in self.original_conds: - self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) - + def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) device = self.model_patcher.load_device @@ -737,14 +868,48 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba latent_image = latent_image.to(device) sigmas = sigmas.to(device) - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + try: + self.model_patcher.pre_run() + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + finally: + self.model_patcher.cleanup() comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model - del self.conds del self.loaded_models return output + def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + if sigmas.shape[-1] == 0: + return latent_image + + self.conds = {} + for k in self.original_conds: + self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) + preprocess_conds_hooks(self.conds) + + try: + orig_model_options = self.model_options + self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options) + # if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step) + orig_hook_mode = self.model_patcher.hook_mode + if get_total_hook_groups_in_conds(self.conds) <= 1: + self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram + comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) + executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( + self.outer_sample, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) + ) + output = 
executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + finally: + self.model_options = orig_model_options + self.model_patcher.hook_mode = orig_hook_mode + self.model_patcher.restore_hook_patches() + + del self.conds + return output + def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): cfg_guider = CFGGuider(model) diff --git a/comfy/sd.py b/comfy/sd.py index e2af7078121..ebae7f99688 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,8 +1,10 @@ +from __future__ import annotations import torch from enum import Enum import logging from comfy import model_management +from comfy.utils import ProgressBar from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_c_coder import StageC_coder @@ -33,6 +35,7 @@ import comfy.model_patcher import comfy.lora import comfy.lora_convert +import comfy.hooks import comfy.t2i_adapter.adapter import comfy.taesd.taesd @@ -98,9 +101,13 @@ def __init__(self, target=None, embedding_directory=None, no_init=False, tokeniz self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram + self.patcher.is_clip = True + self.apply_hooks_to_conds = None if params['device'] == load_device: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None + self.use_clip_schedule = False logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) def clone(self): @@ -109,6 +116,8 @@ def clone(self): n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx + n.use_clip_schedule = self.use_clip_schedule + n.apply_hooks_to_conds = self.apply_hooks_to_conds return n def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): @@ -120,6 +129,69 @@ def clip_layer(self, layer_idx): def tokenize(self, text, return_word_ids=False): return self.tokenizer.tokenize_with_weights(text, return_word_ids) + def add_hooks_to_dict(self, pooled_dict: dict[str]): + if self.apply_hooks_to_conds: + pooled_dict["hooks"] = self.apply_hooks_to_conds + return pooled_dict + + def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True): + all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = [] + all_hooks = self.patcher.forced_hooks + if all_hooks is None or not self.use_clip_schedule: + # if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict + return_pooled = "unprojected" if unprojected else True + pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True) + cond = pooled_dict.pop("cond") + # add/update any keys with the provided add_dict + pooled_dict.update(add_dict) + all_cond_pooled.append([cond, pooled_dict]) + else: + scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule() + + self.cond_stage_model.reset_clip_options() + if self.layer_idx is not None: + self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) + if unprojected: + self.cond_stage_model.set_clip_options({"projected_pooled": False}) + + self.load_model() + all_hooks.reset() + 
self.patcher.patch_hooks(None) + if show_pbar: + pbar = ProgressBar(len(scheduled_keyframes)) + + for scheduled_opts in scheduled_keyframes: + t_range = scheduled_opts[0] + # don't bother encoding any conds outside of start_percent and end_percent bounds + if "start_percent" in add_dict: + if t_range[1] < add_dict["start_percent"]: + continue + if "end_percent" in add_dict: + if t_range[0] > add_dict["end_percent"]: + continue + hooks_keyframes = scheduled_opts[1] + for hook, keyframe in hooks_keyframes: + hook.hook_keyframe._current_keyframe = keyframe + # apply appropriate hooks with values that match new hook_keyframe + self.patcher.patch_hooks(all_hooks) + # perform encoding as normal + o = self.cond_stage_model.encode_token_weights(tokens) + cond, pooled = o[:2] + pooled_dict = {"pooled_output": pooled} + # add clip_start_percent and clip_end_percent in pooled + pooled_dict["clip_start_percent"] = t_range[0] + pooled_dict["clip_end_percent"] = t_range[1] + # add/update any keys with the provided add_dict + pooled_dict.update(add_dict) + # add hooks stored on clip + self.add_hooks_to_dict(pooled_dict) + all_cond_pooled.append([cond, pooled_dict]) + if show_pbar: + pbar.update(1) + model_management.throw_exception_if_processing_interrupted() + all_hooks.reset() + return all_cond_pooled + def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False): self.cond_stage_model.reset_clip_options() @@ -137,6 +209,7 @@ def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False): if len(o) > 2: for k in o[2]: out[k] = o[2][k] + self.add_hooks_to_dict(out) return out if return_pooled: diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py index 3087b917b41..b8e241578e7 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -17,8 +17,7 @@ def INPUT_TYPES(s): def encode(self, clip, ascore, width, height, text): tokens = clip.tokenize(text) - cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) - return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], ) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), ) class CLIPTextEncodeSDXL: @classmethod @@ -47,8 +46,7 @@ def encode(self, clip, width, height, crop_w, crop_h, target_width, target_heigh tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) - return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], ) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), ) NODE_CLASS_MAPPINGS = { "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index b690432b55b..2ae23f73550 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -18,10 +18,7 @@ def encode(self, clip, clip_l, t5xxl, guidance): tokens = clip.tokenize(clip_l) tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] - output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) - cond = output.pop("cond") - output["guidance"] = guidance - return ([[cond, output]], ) + return 
(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), ) class FluxGuidance: @classmethod diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py new file mode 100644 index 00000000000..7b5344e5eac --- /dev/null +++ b/comfy_extras/nodes_hooks.py @@ -0,0 +1,697 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Union +import torch +from collections.abc import Iterable + +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.sd import CLIP + +import comfy.hooks +import comfy.sd +import comfy.utils +import folder_paths + +########################################### +# Mask, Combine, and Hook Conditioning +#------------------------------------------ +class PairConditioningSetProperties: + NodeId = 'PairConditioningSetProperties' + NodeName = 'Cond Pair Set Props' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive_NEW": ("CONDITIONING", ), + "negative_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "set_properties" + + def set_properties(self, positive_NEW, negative_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_positive, final_negative) + +class PairConditioningSetPropertiesAndCombine: + NodeId = 'PairConditioningSetPropertiesAndCombine' + NodeName = 'Cond Pair Set Props Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "positive_NEW": ("CONDITIONING", ), + "negative_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "set_properties" + + def set_properties(self, positive, negative, positive_NEW, negative_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_positive, final_negative) + +class ConditioningSetProperties: + NodeId = 'ConditioningSetProperties' + NodeName = 'Cond Set Props' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": 
("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/cond single" + FUNCTION = "set_properties" + + def set_properties(self, cond_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + (final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_cond,) + +class ConditioningSetPropertiesAndCombine: + NodeId = 'ConditioningSetPropertiesAndCombine' + NodeName = 'Cond Set Props Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond": ("CONDITIONING", ), + "cond_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/cond single" + FUNCTION = "set_properties" + + def set_properties(self, cond, cond_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + (final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_cond,) + +class PairConditioningCombine: + NodeId = 'PairConditioningCombine' + NodeName = 'Cond Pair Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive_A": ("CONDITIONING",), + "negative_A": ("CONDITIONING",), + "positive_B": ("CONDITIONING",), + "negative_B": ("CONDITIONING",), + }, + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "combine" + + def combine(self, positive_A, negative_A, positive_B, negative_B): + final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],) + return (final_positive, final_negative,) + +class PairConditioningSetDefaultAndCombine: + NodeId = 'PairConditioningSetDefaultCombine' + NodeName = 'Cond Pair Set Default Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "positive_DEFAULT": ("CONDITIONING",), + "negative_DEFAULT": ("CONDITIONING",), + }, + "optional": { + "hooks": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "set_default_and_combine" + + def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT, + hooks: comfy.hooks.HookGroup=None): + final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], + hooks=hooks) + return (final_positive, final_negative) + +class ConditioningSetDefaultAndCombine: + NodeId = 'ConditioningSetDefaultCombine' + NodeName = 'Cond Set Default Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond": ("CONDITIONING",), + "cond_DEFAULT": ("CONDITIONING",), + }, + "optional": { + "hooks": ("HOOKS",), + } + } + + 
+###########################################
+# Create Hooks
+#------------------------------------------
+class CreateHookLora:
+    NodeId = 'CreateHookLora'
+    NodeName = 'Create Hook LoRA'
+    def __init__(self):
+        self.loaded_lora = None
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "lora_name": (folder_paths.get_filename_list("loras"), ),
+                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
+                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
+            },
+            "optional": {
+                "prev_hooks": ("HOOKS",)
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOKS",)
+    CATEGORY = "advanced/hooks/create"
+    FUNCTION = "create_hook"
+
+    def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
+        if prev_hooks is None:
+            prev_hooks = comfy.hooks.HookGroup()
+        prev_hooks = prev_hooks.clone()  # clone so the caller's group is never mutated
+
+        if strength_model == 0 and strength_clip == 0:
+            return (prev_hooks,)
+
+        lora_path = folder_paths.get_full_path("loras", lora_name)
+        lora = None
+        if self.loaded_lora is not None:
+            if self.loaded_lora[0] == lora_path:
+                lora = self.loaded_lora[1]
+            else:
+                temp = self.loaded_lora
+                self.loaded_lora = None
+                del temp
+
+        if lora is None:
+            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
+            self.loaded_lora = (lora_path, lora)
+
+        hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
+        return (prev_hooks.clone_and_combine(hooks),)
+
+class CreateHookLoraModelOnly(CreateHookLora):
+    NodeId = 'CreateHookLoraModelOnly'
+    NodeName = 'Create Hook LoRA (MO)'
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "lora_name": (folder_paths.get_filename_list("loras"), ),
+                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
+            },
+            "optional": {
+                "prev_hooks": ("HOOKS",)
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOKS",)
+    CATEGORY = "advanced/hooks/create"
+    FUNCTION = "create_hook_model_only"
+
+    def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
+        return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks)
+
+class CreateHookModelAsLora:
+    NodeId = 'CreateHookModelAsLora'
+    NodeName = 'Create Hook Model as LoRA'
+
+    def __init__(self):
+        # when not None, will be in following format:
+        # (ckpt_path: str, weights_model: dict, weights_clip: dict)
+        self.loaded_weights = None
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
+                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
+                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
+            },
+            "optional": {
+                "prev_hooks": ("HOOKS",)
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOKS",)
+    CATEGORY = "advanced/hooks/create"
+    FUNCTION = "create_hook"
+
+    def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
+                    prev_hooks: comfy.hooks.HookGroup=None):
+        if prev_hooks is None:
+            prev_hooks = comfy.hooks.HookGroup()
+        prev_hooks = prev_hooks.clone()  # clone so the caller's group is never mutated
+
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        weights_model = None
+        weights_clip = None
+        if self.loaded_weights is not None:
+            if self.loaded_weights[0] == ckpt_path:
+                weights_model = self.loaded_weights[1]
+                weights_clip = self.loaded_weights[2]
+            else:
+                temp = self.loaded_weights
+                self.loaded_weights = None
+                del temp
+
+        if weights_model is None:
+            out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+            weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
+            weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
+            self.loaded_weights = (ckpt_path, weights_model, weights_clip)
+
+        hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
+                                                      strength_model=strength_model, strength_clip=strength_clip)
+        return (prev_hooks.clone_and_combine(hooks),)
+
+class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
+    NodeId = 'CreateHookModelAsLoraModelOnly'
+    NodeName = 'Create Hook Model as LoRA (MO)'
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
+                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
+            },
+            "optional": {
+                "prev_hooks": ("HOOKS",)
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOKS",)
+    CATEGORY = "advanced/hooks/create"
+    FUNCTION = "create_hook_model_only"
+
+    def create_hook_model_only(self, ckpt_name: str, strength_model: float,
+                               prev_hooks: comfy.hooks.HookGroup=None):
+        return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks)
+#------------------------------------------
+###########################################
+
+
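+# Usage sketch (comments only, never executed; the LoRA file name is hypothetical).
+# Both Create Hook nodes cache the last loaded file and always clone-and-combine,
+# so chaining prev_hooks inputs never mutates upstream results:
+#
+#   lora_path = folder_paths.get_full_path("loras", "style.safetensors")
+#   lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
+#   lora_hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=0.8, strength_clip=1.0)
+#   all_hooks = comfy.hooks.HookGroup().clone_and_combine(lora_hooks)
+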
+###########################################
+# Schedule Hooks
+#------------------------------------------
+class SetHookKeyframes:
+    NodeId = 'SetHookKeyframes'
+    NodeName = 'Set Hook Keyframes'
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "hooks": ("HOOKS",),
+            },
+            "optional": {
+                "hook_kf": ("HOOK_KEYFRAMES",),
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOKS",)
+    CATEGORY = "advanced/hooks/scheduling"
+    FUNCTION = "set_hook_keyframes"
+
+    def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
+        if hook_kf is not None:
+            hooks = hooks.clone()
+            hooks.set_keyframes_on_hooks(hook_kf=hook_kf)
+        return (hooks,)
+
+class CreateHookKeyframe:
+    NodeId = 'CreateHookKeyframe'
+    NodeName = 'Create Hook Keyframe'
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
+                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+            },
+            "optional": {
+                "prev_hook_kf": ("HOOK_KEYFRAMES",),
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOK_KEYFRAMES",)
+    RETURN_NAMES = ("HOOK_KF",)
+    CATEGORY = "advanced/hooks/scheduling"
+    FUNCTION = "create_hook_keyframe"
+
+    def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
+        if prev_hook_kf is None:
+            prev_hook_kf = comfy.hooks.HookKeyframeGroup()
+        prev_hook_kf = prev_hook_kf.clone()
+        keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent)
+        prev_hook_kf.add(keyframe)
+        return (prev_hook_kf,)
+
+class CreateHookKeyframesFromFloats:
+    NodeId = 'CreateHookKeyframesFromFloats'
+    NodeName = 'Create Hook Keyframes From Floats'
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
+                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "print_keyframes": ("BOOLEAN", {"default": False}),
+            },
+            "optional": {
+                "prev_hook_kf": ("HOOK_KEYFRAMES",),
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOK_KEYFRAMES",)
+    RETURN_NAMES = ("HOOK_KF",)
+    CATEGORY = "advanced/hooks/scheduling"
+    FUNCTION = "create_hook_keyframes"
+
+    def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
+                              start_percent: float, end_percent: float,
+                              prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
+        if prev_hook_kf is None:
+            prev_hook_kf = comfy.hooks.HookKeyframeGroup()
+        prev_hook_kf = prev_hook_kf.clone()
+        if type(floats_strength) in (float, int):
+            floats_strength = [float(floats_strength)]
+        elif isinstance(floats_strength, Iterable):
+            pass
+        else:
+            raise Exception(f"floats_strength must be either an iterable input or a float, but was {type(floats_strength).__name__}.")
+        percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
+                                                               method=comfy.hooks.InterpolationMethod.LINEAR)
+
+        is_first = True
+        for percent, strength in zip(percents, floats_strength):
+            guarantee_steps = 0
+            if is_first:
+                guarantee_steps = 1
+                is_first = False
+            prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
+            if print_keyframes:
+                print(f"Hook Keyframe - start_percent:{percent} = {strength}")
+        return (prev_hook_kf,)
+#------------------------------------------
+###########################################
+
+
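+# Usage sketch (comments only, never executed). Keyframes multiply hook strength
+# over sampling; guarantee_steps=1 on the first keyframe mirrors the behavior of
+# CreateHookKeyframesFromFloats above:
+#
+#   kf = comfy.hooks.HookKeyframeGroup()
+#   kf.add(comfy.hooks.HookKeyframe(strength=1.0, start_percent=0.0, guarantee_steps=1))
+#   kf.add(comfy.hooks.HookKeyframe(strength=0.25, start_percent=0.6))
+#   hooks = hooks.clone()
+#   hooks.set_keyframes_on_hooks(hook_kf=kf)
+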
+class SetModelHooksOnCond:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "conditioning": ("CONDITIONING",),
+                "hooks": ("HOOKS",),
+            },
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("CONDITIONING",)
+    CATEGORY = "advanced/hooks/manual"
+    FUNCTION = "attach_hook"
+
+    def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
+        return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)
+
+
+###########################################
+# Combine Hooks
+#------------------------------------------
+class CombineHooks:
+    NodeId = 'CombineHooks2'
+    NodeName = 'Combine Hooks [2]'
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+            },
+            "optional": {
+                "hooks_A": ("HOOKS",),
+                "hooks_B": ("HOOKS",),
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOKS",)
+    CATEGORY = "advanced/hooks/combine"
+    FUNCTION = "combine_hooks"
+
+    def combine_hooks(self,
+                      hooks_A: comfy.hooks.HookGroup=None,
+                      hooks_B: comfy.hooks.HookGroup=None):
+        candidates = [hooks_A, hooks_B]
+        return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
+
+class CombineHooksFour:
+    NodeId = 'CombineHooks4'
+    NodeName = 'Combine Hooks [4]'
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+            },
+            "optional": {
+                "hooks_A": ("HOOKS",),
+                "hooks_B": ("HOOKS",),
+                "hooks_C": ("HOOKS",),
+                "hooks_D": ("HOOKS",),
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOKS",)
+    CATEGORY = "advanced/hooks/combine"
+    FUNCTION = "combine_hooks"
+
+    def combine_hooks(self,
+                      hooks_A: comfy.hooks.HookGroup=None,
+                      hooks_B: comfy.hooks.HookGroup=None,
+                      hooks_C: comfy.hooks.HookGroup=None,
+                      hooks_D: comfy.hooks.HookGroup=None):
+        candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
+        return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
+
+class CombineHooksEight:
+    NodeId = 'CombineHooks8'
+    NodeName = 'Combine Hooks [8]'
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+            },
+            "optional": {
+                "hooks_A": ("HOOKS",),
+                "hooks_B": ("HOOKS",),
+                "hooks_C": ("HOOKS",),
+                "hooks_D": ("HOOKS",),
+                "hooks_E": ("HOOKS",),
+                "hooks_F": ("HOOKS",),
+                "hooks_G": ("HOOKS",),
+                "hooks_H": ("HOOKS",),
+            }
+        }
+
+    EXPERIMENTAL = True
+    RETURN_TYPES = ("HOOKS",)
+    CATEGORY = "advanced/hooks/combine"
+    FUNCTION = "combine_hooks"
+
+    def combine_hooks(self,
+                      hooks_A: comfy.hooks.HookGroup=None,
+                      hooks_B: comfy.hooks.HookGroup=None,
+                      hooks_C: comfy.hooks.HookGroup=None,
+                      hooks_D: comfy.hooks.HookGroup=None,
+                      hooks_E: comfy.hooks.HookGroup=None,
+                      hooks_F: comfy.hooks.HookGroup=None,
+                      hooks_G: comfy.hooks.HookGroup=None,
+                      hooks_H: comfy.hooks.HookGroup=None):
+        candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
+        return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
+#------------------------------------------
+###########################################
+
+node_list = [
+    # Create
+    CreateHookLora,
+    CreateHookLoraModelOnly,
+    CreateHookModelAsLora,
+    CreateHookModelAsLoraModelOnly,
+    # Scheduling
+    SetHookKeyframes,
+    CreateHookKeyframe,
+    CreateHookKeyframesFromFloats,
+    # Combine
+    CombineHooks,
+    CombineHooksFour,
+    CombineHooksEight,
+    # Attach
+    ConditioningSetProperties,
+    ConditioningSetPropertiesAndCombine,
+    PairConditioningSetProperties,
+    PairConditioningSetPropertiesAndCombine,
+    ConditioningSetDefaultAndCombine,
+    PairConditioningSetDefaultAndCombine,
+    PairConditioningCombine,
+    SetClipHooks,
+    # Other
+    ConditioningTimestepsRange,
+]
+NODE_CLASS_MAPPINGS = {}
+NODE_DISPLAY_NAME_MAPPINGS = {}
+
+for node in node_list:
+    NODE_CLASS_MAPPINGS[node.NodeId] = node
+    NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
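Taken together, the hook nodes wire up end to end as follows; a minimal sketch, assuming a ComfyUI runtime with this patch applied, a loaded `clip` object, and a hypothetical LoRA file name:

    # Build a LoRA weight hook and fade it out at 60% of sampling.
    (hooks,) = CreateHookLora().create_hook("style.safetensors", strength_model=0.8, strength_clip=1.0)
    (hook_kf,) = CreateHookKeyframe().create_hook_keyframe(strength_mult=1.0, start_percent=0.0)
    (hook_kf,) = CreateHookKeyframe().create_hook_keyframe(strength_mult=0.0, start_percent=0.6, prev_hook_kf=hook_kf)
    (hooks,) = SetHookKeyframes().set_hook_keyframes(hooks, hook_kf)

    # Attach to CLIP so encoded conds carry the hooks and honor the schedule.
    (clip,) = SetClipHooks().apply_hooks(clip, schedule_clip=True, apply_to_conds=True, hooks=hooks)
    conditioning = clip.encode_from_tokens_scheduled(clip.tokenize("a painting of a fox"))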
diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py
index b03eaf6a204..2bd295e2459 100644
--- a/comfy_extras/nodes_hunyuan.py
+++ b/comfy_extras/nodes_hunyuan.py
@@ -15,9 +15,7 @@ def encode(self, clip, bert, mt5xl):
         tokens = clip.tokenize(bert)
         tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
 
-        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
-        cond = output.pop("cond")
-        return ([[cond, output]], )
+        return (clip.encode_from_tokens_scheduled(tokens), )
 
 
 NODE_CLASS_MAPPINGS = {
diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py
index 6ef3c293d98..d75b29e606f 100644
--- a/comfy_extras/nodes_sd3.py
+++ b/comfy_extras/nodes_sd3.py
@@ -82,8 +82,7 @@ def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
                 tokens["l"] += empty["l"]
             while len(tokens["l"]) > len(tokens["g"]):
                 tokens["g"] += empty["g"]
-        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        return (clip.encode_from_tokens_scheduled(tokens), )
 
 
 class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
diff --git a/nodes.py b/nodes.py
index fb504da3543..260bb5e153f 100644
--- a/nodes.py
+++ b/nodes.py
@@ -62,9 +62,8 @@ def INPUT_TYPES(s):
 
     def encode(self, clip, text):
         tokens = clip.tokenize(text)
-        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
-        cond = output.pop("cond")
-        return ([[cond, output]], )
+        return (clip.encode_from_tokens_scheduled(tokens), )
+
 
 class ConditioningCombine:
     @classmethod
@@ -2149,6 +2148,7 @@ def init_builtin_extra_nodes():
         "nodes_mochi.py",
         "nodes_slg.py",
         "nodes_lt.py",
+        "nodes_hooks.py",
     ]
 
     import_failed = []
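The text-encode changes above (flux, hunyuan, sd3, and the core CLIPTextEncode) all reduce to the same call. A sketch of what callers get back, assuming a `clip` prepared by Set CLIP Hooks with schedule_clip enabled; the exact extras keys follow this patch's clip_start_percent/clip_end_percent scheduling support:

    tokens = clip.tokenize("a photo of a cat")
    conditioning = clip.encode_from_tokens_scheduled(tokens)
    # A list of [cond_tensor, extras] entries: one entry per scheduled hook range,
    # with extras carrying pooled_output plus clip_start_percent/clip_end_percent
    # when CLIP scheduling is active; a single entry otherwise.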