diff --git a/scripts/cfg_combiner.py b/scripts/cfg_combiner.py
new file mode 100644
index 0000000..c252d10
--- /dev/null
+++ b/scripts/cfg_combiner.py
@@ -0,0 +1,232 @@
+import gradio as gr
+import logging
+import torch
+from modules import shared, scripts, devices, patches, script_callbacks
+from modules.script_callbacks import CFGDenoiserParams
+from modules.processing import StableDiffusionProcessing
+from scripts.incantation_base import UIWrapper
+from scripts.scfg import scfg_combine_denoised
+
+logger = logging.getLogger(__name__)
+
+class CFGCombinerScript(UIWrapper):
+    """ Some scripts modify the CFGs in ways that are not compatible with each other.
+    This script patches the CFG denoiser's combine_denoised function so that the
+    guidance terms are applied in a fixed order. It adds a dict named
+    'incant_cfg_params' to the processing object, containing:
+        'denoiser': the denoiser object
+        'pag_params': list of PAG parameters
+        'scfg_params': the S-CFG parameters
+        ...
+    """
+    def __init__(self):
+        pass
+
+    # Extension title in menu UI
+    def title(self):
+        return "CFG Combiner"
+
+    # Decide to show menu in txt2img or img2img
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    # Setup menu ui detail
+    def setup_ui(self, is_img2img):
+        self.infotext_fields = []
+        self.paste_field_names = []
+        return []
+
+    def before_process(self, p: StableDiffusionProcessing, *args, **kwargs):
+        logger.debug("CFGCombinerScript before_process")
+        cfg_dict = {
+            "denoiser": None,
+            "pag_params": None,
+            "scfg_params": None
+        }
+        setattr(p, 'incant_cfg_params', cfg_dict)
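+        # Illustrative note (sketch): the entries are filled in by the individual
+        # scripts once they are active, e.g.
+        #   p.incant_cfg_params['pag_params'] = pag_params    # PAGExtensionScript.create_hook
+        #   p.incant_cfg_params['scfg_params'] = scfg_params  # SCFGExtensionScript.create_hook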
+ """ + if cfg_dict['denoiser'] is None: + cfg_dict['denoiser'] = params.denoiser + else: + self.unpatch_cfg_denoiser(cfg_dict) + self.patch_cfg_denoiser(params.denoiser, cfg_dict) + + def patch_cfg_denoiser(self, denoiser, cfg_dict: dict): + """ Patch the CFG Denoiser combine_denoised function """ + if not cfg_dict: + logger.error("Unable to patch CFG Denoiser, no dict passed as cfg_dict") + return + if not denoiser: + logger.error("Unable to patch CFG Denoiser, denoiser is None") + return + + if getattr(denoiser, 'combine_denoised_patched', False) is False: + try: + setattr(denoiser, 'combine_denoised_original', denoiser.combine_denoised) + # create patch that references the original function + pass_conds_func = lambda *args, **kwargs: combine_denoised_pass_conds_list( + *args, + **kwargs, + original_func = denoiser.combine_denoised_original, + pag_params = cfg_dict['pag_params'], + scfg_params = cfg_dict['scfg_params'] + ) + patched_combine_denoised = patches.patch(__name__, denoiser, "combine_denoised", pass_conds_func) + setattr(denoiser, 'combine_denoised_patched', True) + setattr(denoiser, 'combine_denoised_original', patches.original(__name__, denoiser, "combine_denoised")) + except KeyError: + logger.exception("KeyError patching combine_denoised") + pass + except RuntimeError: + logger.exception("RuntimeError patching combine_denoised") + pass + + def unpatch_cfg_denoiser(self, cfg_dict = None): + """ Unpatch the CFG Denoiser combine_denoised function """ + if cfg_dict is None: + return + denoiser = cfg_dict.get('denoiser', None) + if denoiser is None: + return + + setattr(denoiser, 'combine_denoised_patched', False) + try: + patches.undo(__name__, denoiser, "combine_denoised") + except KeyError: + logger.exception("KeyError unhooking combine_denoised") + pass + except RuntimeError: + logger.exception("RuntimeError unhooking combine_denoised") + pass + + cfg_dict['denoiser'] = None + + +def combine_denoised_pass_conds_list(*args, **kwargs): + """ Hijacked function for combine_denoised in CFGDenoiser + Currently relies on the original function not having any kwargs + If any of the params are not None, it will apply the corresponding guidance + The order of guidance is: + 1. CFG and S-CFG are combined multiplicatively + 2. PAG guidance is added to the result + 3. ... + ... + """ + original_func = kwargs.get('original_func', None) + pag_params = kwargs.get('pag_params', None) + scfg_params = kwargs.get('scfg_params', None) + + if pag_params is None and scfg_params is None: + logger.warning("No reason to hijack combine_denoised") + return original_func(*args) + + def new_combine_denoised(x_out, conds_list, uncond, cond_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + ### Variables + # 0. Standard CFG Value + cfg_scale = cond_scale + + # 1. CFG Interval + # Overrides cfg_scale if pag_params is not None + if pag_params is not None: + if pag_params.cfg_interval_enable: + cfg_scale = pag_params.cfg_interval_scheduled_value + + # 2. 
+
+
+def combine_denoised_pass_conds_list(*args, **kwargs):
+    """ Hijacked function for combine_denoised in CFGDenoiser
+    Currently relies on the original function not having any kwargs
+    If any of the params are not None, it will apply the corresponding guidance
+    The order of guidance is:
+    1. CFG and S-CFG are combined multiplicatively
+    2. PAG guidance is added to the result
+    3. ...
+    ...
+    """
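+    # In symbols (sketch of the loop below): with delta_c = x_out[c] - x_out[uncond],
+    #   denoised = uncond + sum_c weight_c * cfg_scale * rate(delta_c) * delta_c
+    #                     + sum_c weight_c * pag_scale * (x_out[c] - pag_x_out)
+    # where rate(.) is the S-CFG rate map, or 1.0 when S-CFG is inactive.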
+    original_func = kwargs.get('original_func', None)
+    pag_params = kwargs.get('pag_params', None)
+    scfg_params = kwargs.get('scfg_params', None)
+
+    if pag_params is None and scfg_params is None:
+        logger.warning("No reason to hijack combine_denoised")
+        return original_func(*args)
+
+    def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
+        denoised_uncond = x_out[-uncond.shape[0]:]
+        denoised = torch.clone(denoised_uncond)
+
+        ### Variables
+        # 0. Standard CFG Value
+        cfg_scale = cond_scale
+
+        # 1. CFG Interval
+        # Overrides cfg_scale if pag_params is not None
+        if pag_params is not None:
+            if pag_params.cfg_interval_enable:
+                cfg_scale = pag_params.cfg_interval_scheduled_value
+
+        # 2. PAG
+        pag_x_out = None
+        pag_scale = None
+        if pag_params is not None:
+            pag_x_out = pag_params.pag_x_out
+            pag_scale = pag_params.pag_scale
+
+        ### Combine Denoised
+        for i, conds in enumerate(conds_list):
+            for cond_index, weight in conds:
+
+                model_delta = x_out[cond_index] - denoised_uncond[i]
+
+                # S-CFG
+                rate = 1.0
+                if scfg_params is not None:
+                    rate = scfg_combine_denoised(
+                        model_delta = model_delta,
+                        cfg_scale = cfg_scale,
+                        scfg_params = scfg_params,
+                    )
+                    # scfg_combine_denoised returns a scalar or a rate-map tensor
+                    if rate is None:
+                        logger.error("scfg_combine_denoised returned None, using default rate of 1.0")
+                        rate = 1.0
+                    elif not isinstance(rate, (int, float)):
+                        # rate is a tensor; move it to the model device/dtype
+                        rate = rate.to(device=shared.device, dtype=model_delta.dtype)
+                    else:
+                        # rate is a scalar and can be used as-is
+                        pass
+
+                # 1. Experimental formulation for S-CFG combined with CFG
+                denoised[i] += (model_delta) * rate * (weight * cfg_scale)
+                del rate
+
+                # 2. PAG
+                # PAG is added like CFG
+                if pag_params is not None:
+                    # Not within step interval?
+                    if not pag_params.pag_start_step <= pag_params.step <= pag_params.pag_end_step:
+                        pass
+                    # Scale is zero?
+                    elif pag_scale <= 0:
+                        pass
+                    # do pag
+                    else:
+                        try:
+                            denoised[i] += (x_out[cond_index] - pag_x_out[i]) * (weight * pag_scale)
+                        except Exception as e:
+                            logger.exception("Exception in combine_denoised_pass_conds_list - %s", e)
+
+        #torch.cuda.empty_cache()
+        devices.torch_gc()
+
+        return denoised
+    return new_combine_denoised(*args)
\ No newline at end of file
diff --git a/scripts/incantation_base.py b/scripts/incantation_base.py
index 87298ab..48c4fa1 100644
--- a/scripts/incantation_base.py
+++ b/scripts/incantation_base.py
@@ -13,6 +13,7 @@
 from scripts.scfg import SCFGExtensionScript
 from scripts.pag import PAGExtensionScript
 from scripts.save_attn_maps import SaveAttentionMapsScript
+from scripts.cfg_combiner import CFGCombinerScript
 
 logger = logging.getLogger(__name__)
 logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO))
@@ -43,7 +44,13 @@ def __init__(self, module: UIWrapper, module_idx = 0, num_args = -1, arg_idx = -1):
         submodules.append(SubmoduleInfo(module=SaveAttentionMapsScript()))
     else:
         logger.info("Incantation: Debug scripts are disabled. Set INCANT_DEBUG environment variable to enable them.")
-
+# run these after submodules
+end_submodules: list[SubmoduleInfo] = [
+    SubmoduleInfo(module=CFGCombinerScript())
+]
+submodules = submodules + end_submodules
+
+
 class IncantBaseExtensionScript(scripts.Script):
     def __init__(self):
         pass
@@ -124,7 +131,11 @@ def callback_before_ui():
     try:
         for module_info in submodules:
             module = module_info.module
-            extra_axis_options = module.get_xyz_axis_options()
+            try:
+                extra_axis_options = module.get_xyz_axis_options()
+            except NotImplementedError:
+                logger.warning(f"Module {module.title()} does not implement get_xyz_axis_options")
+                extra_axis_options = {}
             make_axis_options(extra_axis_options)
     except:
         logger.exception("Incantation: Error while making axis options")
diff --git a/scripts/pag.py b/scripts/pag.py
index d6606c3..2a8ef60 100644
--- a/scripts/pag.py
+++ b/scripts/pag.py
@@ -105,9 +105,11 @@ def __init__(self):
         self.cfg_interval_schedule: str = 'Constant'
         self.cfg_interval_low: float = 0
         self.cfg_interval_high: float = 50.0
+        self.cfg_interval_scheduled_value: float = 7.0
         self.step : int = 0
         self.max_sampling_step : int = 1
         self.guidance_scale: int = -1 # CFG
+        self.current_noise_level: float = 100.0
         self.x_in = None
         self.text_cond = None
         self.image_cond = None
@@ -224,6 +226,12 @@ def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, sta
     def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, *args, **kwargs):
         # Create a list of parameters for each concept
        pag_params = PAGStateParams()
+
+        # Add to p's incant_cfg_params
+        if not hasattr(p, 'incant_cfg_params'):
+            logger.error("No incant_cfg_params found in p")
+        p.incant_cfg_params['pag_params'] = pag_params
+
         pag_params.pag_scale = pag_scale
         pag_params.pag_start_step = start_step
         pag_params.pag_end_step = end_step
@@ -233,6 +241,7 @@
         pag_params.guidance_scale = p.cfg_scale
         pag_params.batch_size = p.batch_size
         pag_params.denoiser = None
+        pag_params.cfg_interval_scheduled_value = p.cfg_scale
 
         if pag_params.cfg_interval_enable:
             # Refer to 3.1 Practice in the paper
@@ -264,6 +273,8 @@
         #script_callbacks.on_cfg_after_cfg(after_cfg_lambda)
         script_callbacks.on_script_unloaded(unhook_lambda)
+
+
     def postprocess_batch(self, p, *args, **kwargs):
         self.pag_postprocess_batch(p, *args, **kwargs)
@@ -287,6 +298,7 @@ def remove_all_hooks(self):
     def unhook_callbacks(self, pag_params: PAGStateParams):
         global handles
+        return  # unpatching combine_denoised is now handled by CFGCombinerScript
 
         if pag_params is None:
             logger.error("PAG params is None")
@@ -388,27 +400,24 @@ def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, pag_params: PAGSta
         pag_params.step = params.sampling_step
 
-        # patch combine_denoised
-        if pag_params.denoiser is None:
-            pag_params.denoiser = params.denoiser
-        if getattr(params.denoiser, 'combine_denoised_patched', False) is False:
-            try:
-                setattr(params.denoiser, 'combine_denoised_original', params.denoiser.combine_denoised)
-                # create patch that references the original function
-                pass_conds_func = lambda *args, **kwargs: combine_denoised_pass_conds_list(
-                    *args,
-                    **kwargs,
-                    original_func = params.denoiser.combine_denoised_original,
-                    pag_params = pag_params)
-                pag_params.patched_combine_denoised = patches.patch(__name__, params.denoiser, "combine_denoised", pass_conds_func)
-                setattr(params.denoiser, 'combine_denoised_patched', True)
-                setattr(params.denoiser, 'combine_denoised_original', patches.original(__name__, params.denoiser, "combine_denoised"))
-            except KeyError:
-                logger.exception("KeyError patching combine_denoised")
-                pass
-            except RuntimeError:
-                logger.exception("RuntimeError patching combine_denoised")
-                pass
+        # CFG Interval
+        # TODO: set rho based on sdxl or sd1.5
+        pag_params.current_noise_level = calculate_noise_level(
+            i = pag_params.step,
+            N = pag_params.max_sampling_step,
+        )
+
+        if pag_params.cfg_interval_enable:
+            if pag_params.cfg_interval_schedule != 'Constant':
+                # Calculate noise interval
+                start = pag_params.cfg_interval_low
+                end = pag_params.cfg_interval_high
+                begin_range = start if start <= end else end
+                end_range = end if start <= end else start
+                # Scheduled CFG Value
+                scheduled_cfg_scale = cfg_scheduler(pag_params.cfg_interval_schedule, pag_params.step, pag_params.max_sampling_step, pag_params.guidance_scale)
+
+                pag_params.cfg_interval_scheduled_value = scheduled_cfg_scale if begin_range <= pag_params.current_noise_level <= end_range else 1.0
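+                # Note (sketch): the interval test is in noise-level space, not step
+                # index: calculate_noise_level(i, N) maps step i of N to a noise level,
+                # and the scheduled CFG value only applies while it lies within
+                # [begin_range, end_range].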
 
         # Run only within interval
         if not pag_params.pag_start_step <= params.sampling_step <= pag_params.pag_end_step or pag_params.pag_scale <= 0:
@@ -468,6 +477,8 @@ def on_cfg_denoised_callback(self, params: CFGDenoisedParams, pag_params: PAGSta
         # set pag_enable to False
         for module in pag_params.crossattn_modules:
             setattr(module, 'pag_enable', False)
+
+
     def cfg_after_cfg_callback(self, params: AfterCFGCallbackParams, pag_params: PAGStateParams):
         #self.unhook_callbacks(pag_params)
@@ -506,6 +517,8 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
         # Calculate CFG Scale
         cfg_scale = cond_scale
+        new_params.cfg_interval_scheduled_value = cfg_scale
+
         if new_params.cfg_interval_enable:
             if new_params.cfg_interval_schedule != 'Constant':
                 # Calculate noise interval
@@ -517,6 +530,11 @@
                 scheduled_cfg_scale = cfg_scheduler(new_params.cfg_interval_schedule, new_params.step, new_params.max_sampling_step, cond_scale)
                 # Only apply CFG in the interval
                 cfg_scale = scheduled_cfg_scale if begin_range <= noise_level <= end_range else 1.0
+                new_params.cfg_interval_scheduled_value = scheduled_cfg_scale
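+                # The scheduled value is cached on the params object so that the
+                # combined patch in cfg_combiner.py can reuse it without recomputing
+                # the schedule.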
+
+            # This may be temporarily necessary for compatibility with scfg
+            # if not new_params.pag_start_step <= new_params.step <= new_params.pag_end_step:
+            #     return original_func(*args)
 
         # This may be temporarily necessary for compatibility with scfg
         # if not new_params.pag_start_step <= new_params.step <= new_params.pag_end_step:
diff --git a/scripts/scfg.py b/scripts/scfg.py
index 7f87f0d..afd2367 100644
--- a/scripts/scfg.py
+++ b/scripts/scfg.py
@@ -185,6 +185,12 @@ def pag_process_batch(self, p: StableDiffusionProcessing, active, scfg_scale, sc
     def create_hook(self, p: StableDiffusionProcessing, active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r):
         # Create a list of parameters for each concept
         scfg_params = SCFGStateParams()
+
+        # Add to p's incant_cfg_params
+        if not hasattr(p, 'incant_cfg_params'):
+            logger.error("No incant_cfg_params found in p")
+        p.incant_cfg_params['scfg_params'] = scfg_params
+
         scfg_params.denoiser = None
         scfg_params.all_crossattn_modules = self.get_all_crossattn_modules()
         scfg_params.max_sampling_steps = p.steps
@@ -199,14 +205,14 @@
         scfg_params.width = p.width
 
         # Use lambda to call the callback function with the parameters to avoid global variables
-        cfg_denoise_lambda = lambda callback_params: self.on_cfg_denoiser_callback(callback_params, scfg_params)
+        #cfg_denoise_lambda = lambda callback_params: self.on_cfg_denoiser_callback(callback_params, scfg_params)
         cfg_denoised_lambda = lambda callback_params: self.on_cfg_denoised_callback(callback_params, scfg_params)
         unhook_lambda = lambda _: self.unhook_callbacks(scfg_params)
 
         self.ready_hijack_forward(scfg_params.all_crossattn_modules)
         logger.debug('Hooked callbacks')
 
-        script_callbacks.on_cfg_denoiser(cfg_denoise_lambda)
+        #script_callbacks.on_cfg_denoiser(cfg_denoise_lambda)
         script_callbacks.on_cfg_denoised(cfg_denoised_lambda)
         script_callbacks.on_script_unloaded(unhook_lambda)
@@ -238,21 +244,21 @@ def unhook_callbacks(self, scfg_params: SCFGStateParams):
         global handles
 
         if scfg_params is None:
-            logger.error("PAG params is None")
+            logger.error("SCFG params is None")
             return
 
-        if scfg_params.denoiser is not None:
-            denoiser = scfg_params.denoiser
-            setattr(denoiser, 'combine_denoised_patched_scfg', False)
-            try:
-                patches.undo(__name__, denoiser, "combine_denoised")
-            except KeyError:
-                logger.exception("KeyError unhooking combine_denoised")
-                pass
-            except RuntimeError:
-                logger.exception("RuntimeError unhooking combine_denoised")
-                pass
-            scfg_params.denoiser = None
+        #if scfg_params.denoiser is not None:
+        #    denoiser = scfg_params.denoiser
+        #    setattr(denoiser, 'combine_denoised_patched_scfg', False)
+        #    try:
+        #        patches.undo(__name__, denoiser, "combine_denoised")
+        #    except KeyError:
+        #        logger.exception("KeyError unhooking combine_denoised")
+        #        pass
+        #    except RuntimeError:
+        #        logger.exception("RuntimeError unhooking combine_denoised")
+        #        pass
+        #    scfg_params.denoiser = None
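+        # (combine_denoised patching/unpatching is now handled centrally by
+        # CFGCombinerScript in scripts/cfg_combiner.py)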
 
 
     def ready_hijack_forward(self, all_crossattn_modules):
@@ -317,26 +323,26 @@ def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, scfg_params: SCFGS
         self.unhook_callbacks(scfg_params)
 
         # patch combine_denoised
-        if scfg_params.denoiser is None:
-            scfg_params.denoiser = params.denoiser
-        if getattr(params.denoiser, 'combine_denoised_patched_scfg', False) is False:
-            try:
-                setattr(params.denoiser, 'combine_denoised_original_scfg', params.denoiser.combine_denoised)
-                # create patch that references the original function
-                pass_conds_func = lambda *args, **kwargs: combine_denoised_pass_conds_list(
-                    *args,
-                    **kwargs,
-                    original_func = params.denoiser.combine_denoised_original_scfg,
-                    scfg_params = scfg_params)
-                scfg_params.patched_combine_denoised = patches.patch(__name__, params.denoiser, "combine_denoised", pass_conds_func)
-                setattr(params.denoiser, 'combine_denoised_patched_scfg', True)
-                setattr(params.denoiser, 'combine_denoised_original_scfg', patches.original(__name__, params.denoiser, "combine_denoised"))
-            except KeyError:
-                logger.exception("KeyError patching combine_denoised")
-                pass
-            except RuntimeError:
-                logger.exception("RuntimeError patching combine_denoised")
-                pass
+        # if scfg_params.denoiser is None:
+        #     scfg_params.denoiser = params.denoiser
+        # if getattr(params.denoiser, 'combine_denoised_patched_scfg', False) is False:
+        #     try:
+        #         setattr(params.denoiser, 'combine_denoised_original_scfg', params.denoiser.combine_denoised)
+        #         # create patch that references the original function
+        #         pass_conds_func = lambda *args, **kwargs: combine_denoised_pass_conds_list(
+        #             *args,
+        #             **kwargs,
+        #             original_func = params.denoiser.combine_denoised_original_scfg,
+        #             scfg_params = scfg_params)
+        #         scfg_params.patched_combine_denoised = patches.patch(__name__, params.denoiser, "combine_denoised", pass_conds_func)
+        #         setattr(params.denoiser, 'combine_denoised_patched_scfg', True)
+        #         setattr(params.denoiser, 'combine_denoised_original_scfg', patches.original(__name__, params.denoiser, "combine_denoised"))
+        #     except KeyError:
+        #         logger.exception("KeyError patching combine_denoised")
+        #         pass
+        #     except RuntimeError:
+        #         logger.exception("RuntimeError patching combine_denoised")
+        #         pass
 
     def on_cfg_denoised_callback(self, params: CFGDenoisedParams, scfg_params: SCFGStateParams):
         """ Callback function for the CFGDenoisedParams
@@ -348,6 +354,9 @@
         # Run only within interval
         if not scfg_params.start_step <= params.sampling_step <= scfg_params.end_step:
             return
+
+        if scfg_params.scfg_scale <= 0:
+            return
 
         # S-CFG
         R = scfg_params.R
@@ -381,6 +390,79 @@ def get_xyz_axis_options(self) -> dict:
         return extra_axis_options
 
 
+def scfg_combine_denoised(model_delta, cfg_scale, scfg_params: SCFGStateParams):
+    """ The inner loop of the S-CFG denoiser
+    Arguments:
+        model_delta: torch.Tensor - defined by `x_out[cond_index] - denoised_uncond[i]`
+        cfg_scale: float - guidance scale
+        scfg_params: SCFGStateParams - the state parameters for the S-CFG denoiser
+
+    Returns:
+        float or torch.Tensor - 1.0 if outside the step interval or the scale is 0, else the rate map tensor
+    """
+
+    current_step = scfg_params.current_step
+    start_step = scfg_params.start_step
+    end_step = scfg_params.end_step
+    scfg_scale = scfg_params.scfg_scale
+
+    if not start_step <= current_step <= end_step:
+        return 1.0
+
+    if scfg_scale <= 0:
+        return 1.0
+
+    mask_t = scfg_params.mask_t
+    mask_fore = scfg_params.mask_fore
+    min_rate = scfg_params.rate_min
+    max_rate = scfg_params.rate_max
+    rate_clamp = scfg_params.rate_clamp
+
+    model_delta = model_delta.unsqueeze(0)
+    model_delta_norm = model_delta.norm(dim=1, keepdim=True)
+
+    eps = lambda dtype: torch.finfo(dtype).eps
+
+    # rescale map if necessary
+    if mask_t.shape[2:] != model_delta_norm.shape[2:]:
+        logger.debug('Rescaling mask_t from %s to %s', mask_t.shape[2:], model_delta_norm.shape[2:])
+        mask_t = F.interpolate(mask_t, size=model_delta_norm.shape[2:], mode='bilinear')
+    if mask_fore.shape[-2] != model_delta_norm.shape[-2]:
+        logger.debug('Rescaling mask_fore from %s to %s', mask_fore.shape[2:], model_delta_norm.shape[2:])
+        mask_fore = F.interpolate(mask_fore, size=model_delta_norm.shape[2:], mode='bilinear')
+
+    delta_mask_norms = (model_delta_norm * mask_t).sum([2,3]) / (mask_t.sum([2,3]) + eps(mask_t.dtype))
+    upnormmax = delta_mask_norms.max(dim=1)[0]
+    upnormmax = upnormmax.unsqueeze(-1)
+
+    fore_norms = (model_delta_norm * mask_fore).sum([2,3]) / (mask_fore.sum([2,3]) + eps(mask_fore.dtype))
+
+    up = fore_norms
+    down = delta_mask_norms
+
+    tmp_mask = (mask_t.sum([2,3]) > 0).float()
+    rate = up * (tmp_mask) / (down + eps(down.dtype))  # b 257
+    rate = (rate.unsqueeze(-1).unsqueeze(-1) * mask_t).sum(dim=1, keepdim=True)  # b 1, 64 64
+    del model_delta_norm, delta_mask_norms, upnormmax, fore_norms, up, down, tmp_mask
+
+    # should this go before or after the gaussian blur, or before/after the rate
+    rate = rate * scfg_scale
+
+    rate = torch.clamp(rate, min=min_rate, max=max_rate)
+    rate = torch.clamp_max(rate, rate_clamp / cfg_scale)
+
+    ### Gaussian Smoothing
+    kernel_size = 3
+    sigma = 0.5
+    smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(rate.device)
+    rate = F.pad(rate, (1, 1, 1, 1), mode='reflect')
+    rate = smoothing(rate)
+
+    return rate.squeeze(0)
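+
+# Usage sketch (mirrors cfg_combiner.py above): the caller multiplies the returned
+# rate into its CFG term,
+#   denoised[i] += model_delta * rate * (weight * cfg_scale)
+# so returning 1.0 leaves plain CFG unchanged, while a rate map reweights guidance
+# per spatial location.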
+
+
 def combine_denoised_pass_conds_list(*args, **kwargs):
     """ Hijacked function for combine_denoised in CFGDenoiser """
     original_func = kwargs.get('original_func', None)
@@ -797,8 +879,8 @@ def get_mask(attn_modules, scfg_params: SCFGStateParams, r, latent_size):
         ca = smoothing(ca.float()).squeeze(1)
         ca = rearrange(ca, ' (b c) h w -> b c h w' , c= channel)
 
-        ca_norm = ca/(ca.mean(dim=[2,3], keepdim=True)+1e-8) ### spatial normlization
-
+        ca_norm = ca/(ca.mean(dim=[2,3], keepdim=True)+torch.finfo(ca.dtype).eps) ### spatial normalization
+
         new_ca+=rearrange(ca_norm, '(b n) c h w -> b n c h w', n=attn_num).sum(1)
         fore_ca = torch.stack([ca[:,0],ca[:,1:].sum(dim=1)], dim=1)