Skip to content

Commit

Permalink
Merge pull request #123 from Kosinkadink/sd3-changes
Browse files Browse the repository at this point in the history
SD3 ControlNet support + New ComfyUI Compatibility
  • Loading branch information
Kosinkadink authored Jun 28, 2024
2 parents bf16347 + 3d00251 commit 7a456aa
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 136 deletions.
102 changes: 57 additions & 45 deletions adv_control/control.py

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions adv_control/control_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def create_combo(reference_type: str, style_fidelity: float, ref_weight: float,


class ReferencePreprocWrapper(AbstractPreprocWrapper):
error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of Reference preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
def __init__(self, condhint: Tensor):
super().__init__(condhint)

Expand All @@ -228,10 +228,12 @@ class ReferenceAdvanced(ControlBase, AdvancedControlBase):

def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup, device=None):
super().__init__(device)
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), allow_condhint_latents=True)
# TODO: allow vae_optional to be used instead of preprocessor
#require_vae=True
self.ref_opts = ref_opts
self.order = 0
self.latent_format = None
self.model_latent_format = None
self.model_sampling_current = None
self.should_apply_attn_effective_strength = False
self.should_apply_adain_effective_strength = False
Expand Down Expand Up @@ -288,9 +290,9 @@ def should_run(self):

def pre_run_advanced(self, model, percent_to_timestep_function):
AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function)
if type(self.cond_hint_original) == ReferencePreprocWrapper:
if isinstance(self.cond_hint_original, AbstractPreprocWrapper):
self.cond_hint_original = self.cond_hint_original.condhint
self.latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
self.model_sampling_current = model.model_sampling
# SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments
if type(model).__name__ == "SDXL":
Expand Down Expand Up @@ -328,7 +330,7 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
# noise cond_hint based on sigma (current step)
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.model_latent_format.process_in(self.cond_hint)
self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
timestep = self.model_sampling_current.timestep(t)
self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0))
Expand All @@ -343,8 +345,8 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):

def cleanup_advanced(self):
super().cleanup_advanced()
del self.latent_format
self.latent_format = None
del self.model_latent_format
self.model_latent_format = None
del self.model_sampling_current
self.model_sampling_current = None
self.should_apply_attn_effective_strength = False
Expand Down
45 changes: 10 additions & 35 deletions adv_control/control_sparsectrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
from comfy.ldm.modules.attention import SpatialTransformer
from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default
from comfy.ldm.modules.attention import FeedForward, SpatialTransformer
from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample
from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential
from comfy.model_patcher import ModelPatcher
import comfy.ops
import comfy.model_management

from .logger import logger
from .utils import (BIGMAX, TimestepKeyframeGroup, disable_weight_init_clean_groupnorm,
from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm,
prepare_mask_batch, broadcast_image_to_extend, extend_to_batch_size)


Expand Down Expand Up @@ -85,7 +85,8 @@ def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs)
x = torch.zeros_like(x)
guided_hint = self.input_hint_block(hint, emb, context)

outs = []
out_output = []
out_middle = []

hs = []
if self.num_classes is not None:
Expand All @@ -100,12 +101,12 @@ def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs)
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
out_output.append(zero_conv(h, emb, context))

h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
out_middle.append(self.middle_block_out(h, emb, context))

return outs
return {"middle": out_middle, "output": out_output}


class SparseModelPatcher(ModelPatcher):
Expand Down Expand Up @@ -154,36 +155,10 @@ def clone(self):
self.object_patches_backup = n.object_patches_backup


class PreprocSparseRGBWrapper:
error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
class PreprocSparseRGBWrapper(AbstractPreprocWrapper):
error_msg = error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
def __init__(self, condhint: Tensor):
self.condhint = condhint

def movedim(self, *args, **kwargs):
return self

def __getattr__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __setattr__(self, name, value):
if name != "condhint":
raise AttributeError(self.error_msg)
super().__setattr__(name, value)

def __iter__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __next__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __len__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __getitem__(self, *args, **kwargs):
raise AttributeError(self.error_msg)

def __setitem__(self, *args, **kwargs):
raise AttributeError(self.error_msg)
super().__init__(condhint)


class SparseContextAware:
Expand Down
9 changes: 5 additions & 4 deletions adv_control/control_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):

guided_hint = self.input_hint_block(hint, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)

outs = []
out_output = []
out_middle = []

hs = []
if self.num_classes is not None:
Expand All @@ -326,12 +327,12 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
guided_hint = None
else:
h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
outs.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
out_output.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))

h = self.middle_block(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
outs.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
out_middle.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))

return outs
return {"middle": out_middle, "output": out_output}


TEMPORAL_TRANSFORMER_BLOCKS = {
Expand Down
20 changes: 16 additions & 4 deletions adv_control/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import folder_paths
from comfy.model_patcher import ModelPatcher

from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet
from .utils import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, BIGMAX
from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet, is_sd3_advanced_controlnet
from .utils import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, AbstractPreprocWrapper, BIGMAX
from .nodes_weight import (DefaultWeights, ScaledSoftMaskedUniversalWeights, ScaledSoftUniversalWeights, SoftControlNetWeights, CustomControlNetWeights,
SoftT2IAdapterWeights, CustomT2IAdapterWeights)
from .nodes_keyframes import (LatentKeyframeGroupNode, LatentKeyframeInterpolationNode, LatentKeyframeBatchedGroupNode, LatentKeyframeNode,
Expand Down Expand Up @@ -89,6 +89,7 @@ def INPUT_TYPES(s):
"latent_kf_override": ("LATENT_KEYFRAME", ),
"weights_override": ("CONTROL_NET_WEIGHTS", ),
"model_optional": ("MODEL",),
"vae_optional": ("VAE",),
}
}

Expand All @@ -99,7 +100,7 @@ def INPUT_TYPES(s):
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"

def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent,
mask_optional: Tensor=None, model_optional: ModelPatcher=None,
mask_optional: Tensor=None, model_optional: ModelPatcher=None, vae_optional=None,
timestep_kf: TimestepKeyframeGroup=None, latent_kf_override: LatentKeyframeGroup=None,
weights_override: ControlWeights=None):
if strength == 0:
Expand All @@ -121,7 +122,7 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
c_net = cnets[prev_cnet]
else:
# copy, convert to advanced if needed, and set cond
c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent))
c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent), vae_optional)
if is_advanced_controlnet(c_net):
# disarm node check
c_net.disarm()
Expand All @@ -130,6 +131,17 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
if not model_optional:
raise Exception(f"Type '{type(c_net).__name__}' requires model_optional input, but got None.")
c_net.patch_model(model=model_optional)
# if vae required, verify vae is passed in
if c_net.require_vae:
# if controlnet can accept preprocced condhint latents and is the case, ignore vae requirement
if c_net.allow_condhint_latents and isinstance(control_hint, AbstractPreprocWrapper):
pass
elif not vae_optional:
# make sure SD3 ControlNet will get a special message instead of generic type mention
if is_sd3_advanced_controlnet:
raise Exception(f"SD3 ControlNet requires vae_optional input, but got None.")
else:
raise Exception(f"Type '{type(c_net).__name__}' requires vae_optional input, but got None.")
# apply optional parameters and overrides, if provided
if timestep_kf is not None:
c_net.set_timestep_keyframes(timestep_kf)
Expand Down
2 changes: 2 additions & 0 deletions adv_control/nodes_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
weights = [weight_00, weight_01, weight_02, weight_03]
weights = get_properly_arranged_t2i_weights(weights)
weights.reverse() # to account for recent ComfyUI changes
weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier, extras=cn_extras)
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))

Expand Down Expand Up @@ -229,5 +230,6 @@ def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
weights = [weight_00, weight_01, weight_02, weight_03]
weights = get_properly_arranged_t2i_weights(weights)
weights.reverse() # to account for recent ComfyUI changes
weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier, extras=cn_extras)
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
Loading

0 comments on commit 7a456aa

Please sign in to comment.