diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index eb9a40b0d74..a5d9e172602 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -35,8 +35,11 @@ def activate(self, p, params_list): names.append(params.positional[0]) te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0 te_multiplier = float(params.named.get("te", te_multiplier)) - unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier - unet_multiplier = float(params.named.get("unet", unet_multiplier)) + unet_multiplier = [float(params.positional[2]) if len(params.positional) > 2 else te_multiplier] * 3 + unet_multiplier = [float(params.named.get("unet", unet_multiplier[0]))] * 3 + unet_multiplier[0] = float(params.named.get("in", unet_multiplier[0])) + unet_multiplier[1] = float(params.named.get("mid", unet_multiplier[1])) + unet_multiplier[2] = float(params.named.get("out", unet_multiplier[2])) dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim te_multipliers.append(te_multiplier) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index d7c15877036..e5828daf3f2 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -83,7 +83,7 @@ def __init__(self, name, network_on_disk: NetworkOnDisk): self.name = name self.network_on_disk = network_on_disk self.te_multiplier = 1.0 - self.unet_multiplier = 1.0 + self.unet_multiplier = [1.0] * 3 self.dyn_dim = None self.modules = {} self.mtime = None @@ -112,8 +112,14 @@ def __init__(self, net: Network, weights: NetworkWeights): def multiplier(self): if 'transformer' in self.sd_key[:20]: return self.network.te_multiplier + if "down_blocks" in self.sd_key: + return self.network.unet_multiplier[0] + if "mid_block" in self.sd_key: + return self.network.unet_multiplier[1] + if "up_blocks" in self.sd_key: + return self.network.unet_multiplier[2] else: - return self.network.unet_multiplier + return self.network.unet_multiplier[0] def calc_scale(self): if self.scale is not None: