Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SD3 ControlNet support + New ComfyUI Compatibility #123

Merged
merged 11 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 57 additions & 45 deletions adv_control/control.py

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions adv_control/control_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def create_combo(reference_type: str, style_fidelity: float, ref_weight: float,


class ReferencePreprocWrapper(AbstractPreprocWrapper):
error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of Reference preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
def __init__(self, condhint: Tensor):
super().__init__(condhint)

Expand All @@ -228,10 +228,12 @@ class ReferenceAdvanced(ControlBase, AdvancedControlBase):

def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup, device=None):
super().__init__(device)
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), allow_condhint_latents=True)
# TODO: allow vae_optional to be used instead of preprocessor
#require_vae=True
self.ref_opts = ref_opts
self.order = 0
self.latent_format = None
self.model_latent_format = None
self.model_sampling_current = None
self.should_apply_attn_effective_strength = False
self.should_apply_adain_effective_strength = False
Expand Down Expand Up @@ -288,9 +290,9 @@ def should_run(self):

def pre_run_advanced(self, model, percent_to_timestep_function):
AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function)
if type(self.cond_hint_original) == ReferencePreprocWrapper:
if isinstance(self.cond_hint_original, AbstractPreprocWrapper):
self.cond_hint_original = self.cond_hint_original.condhint
self.latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
self.model_sampling_current = model.model_sampling
# SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments
if type(model).__name__ == "SDXL":
Expand Down Expand Up @@ -328,7 +330,7 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
# noise cond_hint based on sigma (current step)
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.model_latent_format.process_in(self.cond_hint)
self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
timestep = self.model_sampling_current.timestep(t)
self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0))
Expand All @@ -343,8 +345,8 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):

def cleanup_advanced(self):
super().cleanup_advanced()
del self.latent_format
self.latent_format = None
del self.model_latent_format
self.model_latent_format = None
del self.model_sampling_current
self.model_sampling_current = None
self.should_apply_attn_effective_strength = False
Expand Down
45 changes: 10 additions & 35 deletions adv_control/control_sparsectrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
from comfy.ldm.modules.attention import SpatialTransformer
from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default
from comfy.ldm.modules.attention import FeedForward, SpatialTransformer
from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample
from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential
from comfy.model_patcher import ModelPatcher
import comfy.ops
import comfy.model_management

from .logger import logger
from .utils import (BIGMAX, TimestepKeyframeGroup, disable_weight_init_clean_groupnorm,
from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm,
prepare_mask_batch, broadcast_image_to_extend, extend_to_batch_size)


Expand Down Expand Up @@ -85,7 +85,8 @@ def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs)
x = torch.zeros_like(x)
guided_hint = self.input_hint_block(hint, emb, context)

outs = []
out_output = []
out_middle = []

hs = []
if self.num_classes is not None:
Expand All @@ -100,12 +101,12 @@ def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs)
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
out_output.append(zero_conv(h, emb, context))

h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
out_middle.append(self.middle_block_out(h, emb, context))

return outs
return {"middle": out_middle, "output": out_output}


class SparseModelPatcher(ModelPatcher):
Expand Down Expand Up @@ -154,36 +155,10 @@ def clone(self):
self.object_patches_backup = n.object_patches_backup


class PreprocSparseRGBWrapper:
error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
class PreprocSparseRGBWrapper(AbstractPreprocWrapper):
error_msg = error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
def __init__(self, condhint: Tensor):
self.condhint = condhint

def movedim(self, *args, **kwargs):
return self

def __getattr__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __setattr__(self, name, value):
if name != "condhint":
raise AttributeError(self.error_msg)
super().__setattr__(name, value)

def __iter__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __next__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __len__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __getitem__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __setitem__(self, *args, **kwargs):
raise AttributeError(self.error_msg)
super().__init__(condhint)


class SparseContextAware:
Expand Down
9 changes: 5 additions & 4 deletions adv_control/control_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):

guided_hint = self.input_hint_block(hint, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)

outs = []
out_output = []
out_middle = []

hs = []
if self.num_classes is not None:
Expand All @@ -326,12 +327,12 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
guided_hint = None
else:
h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
outs.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
out_output.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))

h = self.middle_block(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
outs.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
out_middle.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))

return outs
return {"middle": out_middle, "output": out_output}


TEMPORAL_TRANSFORMER_BLOCKS = {
Expand Down
20 changes: 16 additions & 4 deletions adv_control/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import folder_paths
from comfy.model_patcher import ModelPatcher

from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet
from .utils import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, BIGMAX
from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet, is_sd3_advanced_controlnet
from .utils import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, AbstractPreprocWrapper, BIGMAX
from .nodes_weight import (DefaultWeights, ScaledSoftMaskedUniversalWeights, ScaledSoftUniversalWeights, SoftControlNetWeights, CustomControlNetWeights,
SoftT2IAdapterWeights, CustomT2IAdapterWeights)
from .nodes_keyframes import (LatentKeyframeGroupNode, LatentKeyframeInterpolationNode, LatentKeyframeBatchedGroupNode, LatentKeyframeNode,
Expand Down Expand Up @@ -89,6 +89,7 @@ def INPUT_TYPES(s):
"latent_kf_override": ("LATENT_KEYFRAME", ),
"weights_override": ("CONTROL_NET_WEIGHTS", ),
"model_optional": ("MODEL",),
"vae_optional": ("VAE",),
}
}

Expand All @@ -99,7 +100,7 @@ def INPUT_TYPES(s):
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"

def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent,
mask_optional: Tensor=None, model_optional: ModelPatcher=None,
mask_optional: Tensor=None, model_optional: ModelPatcher=None, vae_optional=None,
timestep_kf: TimestepKeyframeGroup=None, latent_kf_override: LatentKeyframeGroup=None,
weights_override: ControlWeights=None):
if strength == 0:
Expand All @@ -121,7 +122,7 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
c_net = cnets[prev_cnet]
else:
# copy, convert to advanced if needed, and set cond
c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent))
c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent), vae_optional)
if is_advanced_controlnet(c_net):
# disarm node check
c_net.disarm()
Expand All @@ -130,6 +131,17 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
if not model_optional:
raise Exception(f"Type '{type(c_net).__name__}' requires model_optional input, but got None.")
c_net.patch_model(model=model_optional)
# if vae required, verify vae is passed in
if c_net.require_vae:
# if the controlnet accepts preprocessed condhint latents and one was passed in as control_hint, ignore the vae requirement
if c_net.allow_condhint_latents and isinstance(control_hint, AbstractPreprocWrapper):
pass
elif not vae_optional:
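# make sure an SD3 ControlNet gets a dedicated message instead of the generic type-name message
# NOTE(review): the condition below tests the name is_sd3_advanced_controlnet itself rather than calling it (compare is_advanced_controlnet(c_net) above); if it is a function this is always truthy and the generic branch is unreachable - presumably it should be is_sd3_advanced_controlnet(c_net); confirm against its definition in control.py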
if is_sd3_advanced_controlnet:
raise Exception(f"SD3 ControlNet requires vae_optional input, but got None.")
else:
raise Exception(f"Type '{type(c_net).__name__}' requires vae_optional input, but got None.")
# apply optional parameters and overrides, if provided
if timestep_kf is not None:
c_net.set_timestep_keyframes(timestep_kf)
Expand Down
2 changes: 2 additions & 0 deletions adv_control/nodes_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
weights = [weight_00, weight_01, weight_02, weight_03]
weights = get_properly_arranged_t2i_weights(weights)
weights.reverse() # to account for recent ComfyUI changes
weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier, extras=cn_extras)
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))

Expand Down Expand Up @@ -229,5 +230,6 @@ def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
weights = [weight_00, weight_01, weight_02, weight_03]
weights = get_properly_arranged_t2i_weights(weights)
weights.reverse() # to account for recent ComfyUI changes
weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier, extras=cn_extras)
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
Loading