Support multiple scripts modifying latent #36

Merged: 15 commits, May 14, 2024
232 changes: 232 additions & 0 deletions scripts/cfg_combiner.py
@@ -0,0 +1,232 @@
import gradio as gr
import logging
import torch
from modules import shared, scripts, devices, patches, script_callbacks
from modules.script_callbacks import CFGDenoiserParams
from modules.processing import StableDiffusionProcessing
from scripts.incantation_base import UIWrapper
from scripts.scfg import scfg_combine_denoised

logger = logging.getLogger(__name__)

class CFGCombinerScript(UIWrapper):
    """ Some scripts modify the CFGs in ways that are not compatible with each other.
    This script will patch the CFG denoiser function to apply CFG in an ordered way.
    This script adds a dict named 'incant_cfg_params' to the processing object.
    This dict contains the following:
        'denoiser': the denoiser object
        'pag_params': list of PAG parameters
        'scfg_params': the S-CFG parameters
        ...
    """
    def __init__(self):
        pass

    # Extension title in menu UI
    def title(self):
        return "CFG Combiner"

    # Decide to show menu in txt2img or img2img
    def show(self, is_img2img):
        return scripts.AlwaysVisible

    # Setup menu ui detail
    def setup_ui(self, is_img2img):
        self.infotext_fields = []
        self.paste_field_names = []
        return []

    def before_process(self, p: StableDiffusionProcessing, *args, **kwargs):
        logger.debug("CFGCombinerScript before_process")
        cfg_dict = {
            "denoiser": None,
            "pag_params": None,
            "scfg_params": None
        }
        setattr(p, 'incant_cfg_params', cfg_dict)

    def process(self, p: StableDiffusionProcessing, *args, **kwargs):
        pass

    def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        pass

    def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        """ Process the batch and hook the CFG denoiser if PAG or S-CFG is active """
        logger.debug("CFGCombinerScript process_batch")
        pag_active = p.extra_generation_params.get('PAG Active', False)
        scfg_active = p.extra_generation_params.get('SCFG Active', False)

        if not any([
                pag_active,
                scfg_active
        ]):
            return

        logger.debug("CFGCombinerScript process_batch: pag_active or scfg_active")

        cfg_denoise_lambda = lambda params: self.on_cfg_denoiser_callback(params, p.incant_cfg_params)
        unhook_lambda = lambda: self.unhook_callbacks()

        script_callbacks.on_cfg_denoiser(cfg_denoise_lambda)
        script_callbacks.on_script_unloaded(unhook_lambda)
        logger.debug('Hooked callbacks')

    def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        logger.debug("CFGCombinerScript postprocess_batch")
        script_callbacks.remove_current_script_callbacks()

    def unhook_callbacks(self, cfg_dict = None):
        if not cfg_dict:
            return
        self.unpatch_cfg_denoiser(cfg_dict)

    def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, cfg_dict: dict):
        """ Callback for when the CFG denoiser is called
        Patches the combine_denoised function with a custom one.
        """
        if cfg_dict['denoiser'] is None:
            cfg_dict['denoiser'] = params.denoiser
        else:
            self.unpatch_cfg_denoiser(cfg_dict)
        self.patch_cfg_denoiser(params.denoiser, cfg_dict)

    def patch_cfg_denoiser(self, denoiser, cfg_dict: dict):
        """ Patch the CFG Denoiser combine_denoised function """
        if not cfg_dict:
            logger.error("Unable to patch CFG Denoiser, no dict passed as cfg_dict")
            return
        if not denoiser:
            logger.error("Unable to patch CFG Denoiser, denoiser is None")
            return

        if getattr(denoiser, 'combine_denoised_patched', False) is False:
            try:
                setattr(denoiser, 'combine_denoised_original', denoiser.combine_denoised)
                # create patch that references the original function
                pass_conds_func = lambda *args, **kwargs: combine_denoised_pass_conds_list(
                    *args,
                    **kwargs,
                    original_func = denoiser.combine_denoised_original,
                    pag_params = cfg_dict['pag_params'],
                    scfg_params = cfg_dict['scfg_params']
                )
                patched_combine_denoised = patches.patch(__name__, denoiser, "combine_denoised", pass_conds_func)
                setattr(denoiser, 'combine_denoised_patched', True)
                setattr(denoiser, 'combine_denoised_original', patches.original(__name__, denoiser, "combine_denoised"))
            except KeyError:
                logger.exception("KeyError patching combine_denoised")
                pass
            except RuntimeError:
                logger.exception("RuntimeError patching combine_denoised")
                pass

    def unpatch_cfg_denoiser(self, cfg_dict = None):
        """ Unpatch the CFG Denoiser combine_denoised function """
        if cfg_dict is None:
            return
        denoiser = cfg_dict.get('denoiser', None)
        if denoiser is None:
            return

        setattr(denoiser, 'combine_denoised_patched', False)
        try:
            patches.undo(__name__, denoiser, "combine_denoised")
        except KeyError:
            logger.exception("KeyError unhooking combine_denoised")
            pass
        except RuntimeError:
            logger.exception("RuntimeError unhooking combine_denoised")
            pass

        cfg_dict['denoiser'] = None


def combine_denoised_pass_conds_list(*args, **kwargs):
    """ Hijacked function for combine_denoised in CFGDenoiser
    Currently relies on the original function not having any kwargs
    If any of the params are not None, it will apply the corresponding guidance
    The order of guidance is:
    1. CFG and S-CFG are combined multiplicatively
    2. PAG guidance is added to the result
    3. ...
    ...
    """
    original_func = kwargs.get('original_func', None)
    pag_params = kwargs.get('pag_params', None)
    scfg_params = kwargs.get('scfg_params', None)

    if pag_params is None and scfg_params is None:
        logger.warning("No reason to hijack combine_denoised")
        return original_func(*args)

    def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        ### Variables
        # 0. Standard CFG Value
        cfg_scale = cond_scale

        # 1. CFG Interval
        # Overrides cfg_scale if pag_params is not None
        if pag_params is not None:
            if pag_params.cfg_interval_enable:
                cfg_scale = pag_params.cfg_interval_scheduled_value

        # 2. PAG
        pag_x_out = None
        pag_scale = None
        if pag_params is not None:
            pag_x_out = pag_params.pag_x_out
            pag_scale = pag_params.pag_scale

        ### Combine Denoised
        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:

                model_delta = x_out[cond_index] - denoised_uncond[i]

                # S-CFG
                rate = 1.0
                if scfg_params is not None:
                    rate = scfg_combine_denoised(
                        model_delta = model_delta,
                        cfg_scale = cfg_scale,
                        scfg_params = scfg_params,
                    )
                # Normalize the S-CFG rate: fall back to 1.0 on None, move tensors to the right device/dtype
                if rate is None:
                    logger.error("scfg_combine_denoised returned None, using default rate of 1.0")
                    rate = 1.0
                elif not isinstance(rate, (int, float)):
                    # rate is a tensor, move it to the model's device and dtype
                    rate = rate.to(device=shared.device, dtype=model_delta.dtype)
                else:
                    # rate is a plain int/float, use it as-is
                    pass

                # 1. Experimental formulation for S-CFG combined with CFG
                denoised[i] += (model_delta) * rate * (weight * cfg_scale)
                del rate

                # 2. PAG
                # PAG is added like CFG
                if pag_params is not None:
                    # Not within step interval?
                    if not pag_params.pag_start_step <= pag_params.step <= pag_params.pag_end_step:
                        pass
                    # Scale is zero?
                    elif pag_scale <= 0:
                        pass
                    # do pag
                    else:
                        try:
                            denoised[i] += (x_out[cond_index] - pag_x_out[i]) * (weight * pag_scale)
                        except Exception as e:
                            logger.exception("Exception in combine_denoised_pass_conds_list - %s", e)

        #torch.cuda.empty_cache()
        devices.torch_gc()

        return denoised
    return new_combine_denoised(*args)
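For reference, the guidance scripts that use this combiner are expected to register their per-batch parameters on the shared dict that CFGCombinerScript attaches in before_process (pag.py does this for 'pag_params' further down in this diff). A minimal sketch of that registration pattern; SomeGuidanceScript and SomeGuidanceParams are illustrative names, not part of this PR:

class SomeGuidanceParams:
    """ Hypothetical per-batch parameter object for a guidance script """
    def __init__(self):
        self.scale = 1.0

class SomeGuidanceScript(UIWrapper):
    def create_hook(self, p: StableDiffusionProcessing, *args, **kwargs):
        params = SomeGuidanceParams()
        # CFGCombinerScript.before_process has already attached the shared dict to p
        if hasattr(p, 'incant_cfg_params'):
            p.incant_cfg_params['pag_params'] = params   # or 'scfg_params' for S-CFG
        else:
            logger.error("No incant_cfg_params found in p")

The combiner then reads these entries when it patches combine_denoised, so the individual scripts no longer need to patch the denoiser themselves.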
15 changes: 13 additions & 2 deletions scripts/incantation_base.py
@@ -13,6 +13,7 @@
from scripts.scfg import SCFGExtensionScript
from scripts.pag import PAGExtensionScript
from scripts.save_attn_maps import SaveAttentionMapsScript
from scripts.cfg_combiner import CFGCombinerScript

logger = logging.getLogger(__name__)
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO))
@@ -43,7 +44,13 @@ def __init__(self, module: UIWrapper, module_idx = 0, num_args = -1, arg_idx = -
    submodules.append(SubmoduleInfo(module=SaveAttentionMapsScript()))
else:
    logger.info("Incantation: Debug scripts are disabled. Set INCANT_DEBUG environment variable to enable them.")

# run these after submodules
end_submodules: list[SubmoduleInfo] = [
    SubmoduleInfo(module=CFGCombinerScript())
]
submodules = submodules + end_submodules


class IncantBaseExtensionScript(scripts.Script):
    def __init__(self):
        pass
@@ -124,7 +131,11 @@ def callback_before_ui():
    try:
        for module_info in submodules:
            module = module_info.module
            extra_axis_options = module.get_xyz_axis_options()
            try:
                extra_axis_options = module.get_xyz_axis_options()
            except NotImplementedError:
                logger.warning(f"Module {module.title()} does not implement get_xyz_axis_options")
                extra_axis_options = {}
            make_axis_options(extra_axis_options)
    except:
        logger.exception("Incantation: Error while making axis options")
60 changes: 39 additions & 21 deletions scripts/pag.py
@@ -105,9 +105,11 @@ def __init__(self):
        self.cfg_interval_schedule: str = 'Constant'
        self.cfg_interval_low: float = 0
        self.cfg_interval_high: float = 50.0
        self.cfg_interval_scheduled_value: float = 7.0
        self.step : int = 0
        self.max_sampling_step : int = 1
        self.guidance_scale: int = -1 # CFG
        self.current_noise_level: float = 100.0
        self.x_in = None
        self.text_cond = None
        self.image_cond = None
@@ -224,6 +226,12 @@ def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, sta
    def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, *args, **kwargs):
        # Create a list of parameters for each concept
        pag_params = PAGStateParams()

        # Add to p's incant_cfg_params
        if not hasattr(p, 'incant_cfg_params'):
            logger.error("No incant_cfg_params found in p")
        p.incant_cfg_params['pag_params'] = pag_params

        pag_params.pag_scale = pag_scale
        pag_params.pag_start_step = start_step
        pag_params.pag_end_step = end_step
@@ -233,6 +241,7 @@ def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_ste
        pag_params.guidance_scale = p.cfg_scale
        pag_params.batch_size = p.batch_size
        pag_params.denoiser = None
        pag_params.cfg_interval_scheduled_value = p.cfg_scale

        if pag_params.cfg_interval_enable:
            # Refer to 3.1 Practice in the paper
@@ -264,6 +273,8 @@ def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_ste
        #script_callbacks.on_cfg_after_cfg(after_cfg_lambda)
        script_callbacks.on_script_unloaded(unhook_lambda)



    def postprocess_batch(self, p, *args, **kwargs):
        self.pag_postprocess_batch(p, *args, **kwargs)

@@ -287,6 +298,7 @@ def remove_all_hooks(self):

    def unhook_callbacks(self, pag_params: PAGStateParams):
        global handles
        return

        if pag_params is None:
            logger.error("PAG params is None")
@@ -388,27 +400,24 @@ def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, pag_params: PAGSta

        pag_params.step = params.sampling_step

        # patch combine_denoised
        if pag_params.denoiser is None:
            pag_params.denoiser = params.denoiser
            if getattr(params.denoiser, 'combine_denoised_patched', False) is False:
                try:
                    setattr(params.denoiser, 'combine_denoised_original', params.denoiser.combine_denoised)
                    # create patch that references the original function
                    pass_conds_func = lambda *args, **kwargs: combine_denoised_pass_conds_list(
                        *args,
                        **kwargs,
                        original_func = params.denoiser.combine_denoised_original,
                        pag_params = pag_params)
                    pag_params.patched_combine_denoised = patches.patch(__name__, params.denoiser, "combine_denoised", pass_conds_func)
                    setattr(params.denoiser, 'combine_denoised_patched', True)
                    setattr(params.denoiser, 'combine_denoised_original', patches.original(__name__, params.denoiser, "combine_denoised"))
                except KeyError:
                    logger.exception("KeyError patching combine_denoised")
                    pass
                except RuntimeError:
                    logger.exception("RuntimeError patching combine_denoised")
                    pass
        # CFG Interval
        # TODO: set rho based on sdxl or sd1.5
        pag_params.current_noise_level = calculate_noise_level(
            i = pag_params.step,
            N = pag_params.max_sampling_step,
        )

        if pag_params.cfg_interval_enable:
            if pag_params.cfg_interval_schedule != 'Constant':
                # Calculate noise interval
                start = pag_params.cfg_interval_low
                end = pag_params.cfg_interval_high
                begin_range = start if start <= end else end
                end_range = end if start <= end else start
                # Scheduled CFG Value
                scheduled_cfg_scale = cfg_scheduler(pag_params.cfg_interval_schedule, pag_params.step, pag_params.max_sampling_step, pag_params.guidance_scale)

                pag_params.cfg_interval_scheduled_value = scheduled_cfg_scale if begin_range <= pag_params.current_noise_level <= end_range else 1.0
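                # When the current noise level falls outside [cfg_interval_low, cfg_interval_high],
                # the scheduled value falls back to 1.0, i.e. the combiner applies a CFG scale of 1.0 for those steps.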

        # Run only within interval
        if not pag_params.pag_start_step <= params.sampling_step <= pag_params.pag_end_step or pag_params.pag_scale <= 0:
@@ -468,6 +477,8 @@ def on_cfg_denoised_callback(self, params: CFGDenoisedParams, pag_params: PAGSta
        # set pag_enable to False
        for module in pag_params.crossattn_modules:
            setattr(module, 'pag_enable', False)



    def cfg_after_cfg_callback(self, params: AfterCFGCallbackParams, pag_params: PAGStateParams):
        #self.unhook_callbacks(pag_params)
@@ -506,6 +517,8 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):

        # Calculate CFG Scale
        cfg_scale = cond_scale
        new_params.cfg_interval_scheduled_value = cfg_scale

        if new_params.cfg_interval_enable:
            if new_params.cfg_interval_schedule != 'Constant':
                # Calculate noise interval
@@ -517,6 +530,11 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
                scheduled_cfg_scale = cfg_scheduler(new_params.cfg_interval_schedule, new_params.step, new_params.max_sampling_step, cond_scale)
                # Only apply CFG in the interval
                cfg_scale = scheduled_cfg_scale if begin_range <= noise_level <= end_range else 1.0
                new_params.cfg_interval_scheduled_value = scheduled_cfg_scale

        # This may be temporarily necessary for compatibility with scfg
        # if not new_params.pag_start_step <= new_params.step <= new_params.pag_end_step:
        #     return original_func(*args)

        # This may be temporarily necessary for compatibility with scfg
        # if not new_params.pag_start_step <= new_params.step <= new_params.pag_end_step: