
Commit

update loraplus on dylora/lora_fa
kohya-ss committed May 6, 2024
1 parent 52e64c6 commit 7fe8150
Showing 3 changed files with 71 additions and 34 deletions.
46 changes: 29 additions & 17 deletions networks/dylora.py
@@ -18,10 +18,13 @@
import torch
from torch import nn
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class DyLoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -195,7 +198,7 @@ def create_network(
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)

if unit is not None:
unit = int(unit)
else:
@@ -211,6 +214,16 @@
unit=unit,
varbose=True,
)

loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

return network
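Side note (illustrative, not part of the commit): the ratios read from kwargs above are cast with float() because values supplied via --network_args typically reach create_network() as strings. A minimal sketch of that parsing pattern with a made-up value:

# Minimal sketch of the parsing pattern above (hypothetical input, not part of the diff):
# a ratio passed via --network_args arrives as a string, is cast to float, and stays
# None when it was not supplied at all.
kwargs = {"loraplus_lr_ratio": "16"}  # hypothetical network_args value
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
print(loraplus_lr_ratio)  # 16.0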


@@ -280,6 +293,10 @@ def __init__(
self.alpha = alpha
self.apply_to_conv = apply_to_conv

self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None

if modules_dim is not None:
logger.info("create LoRA network from weights")
else:
@@ -320,9 +337,9 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
loras.append(lora)
return loras

text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

self.text_encoder_loras = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
@@ -331,7 +348,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
else:
index = None
logger.info("create LoRA for Text Encoder")

text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)

@@ -346,6 +363,11 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
self.unet_loras = create_modules(True, unet, target_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
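Side note (illustrative, not part of the commit): with these three attributes, the component-specific ratio is expected to take precedence over the network-wide one when optimizer parameters are assembled. A small hypothetical helper showing that resolution order:

# Hypothetical helper, assuming the fallback chain used in prepare_optimizer_params:
# a component-specific ratio wins; otherwise the network-wide ratio applies.
def resolve_loraplus_ratio(network, is_unet: bool):
    if is_unet:
        return network.loraplus_unet_lr_ratio or network.loraplus_lr_ratio
    return network.loraplus_text_encoder_lr_ratio or network.loraplus_lr_ratio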

def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
@@ -407,15 +429,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
"""

# It might be good to allow setting separate learning rates for the two Text Encoders
def prepare_optimizer_params(
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []

@@ -452,15 +466,13 @@ def assemble_params(loras, lr, ratio):
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)

if self.unet_loras:
params = assemble_params(
self.unet_loras,
default_lr if unet_lr is None else unet_lr,
unet_loraplus_ratio or loraplus_ratio
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
)
all_params.extend(params)
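Side note (illustrative, not part of the commit): the ratio is consumed by assemble_params, which is defined earlier in this file and is unchanged here. Roughly, LoRA+ trains the "up" (B) side of each LoRA pair at lr * ratio while the "down" (A) side stays at lr. The sketch below approximates that grouping and the new calling convention; it is not the actual implementation, and parameter naming may differ per module type:

# Rough sketch of the LoRA+ grouping idea (approximation, not the real assemble_params).
def assemble_params_sketch(loras, lr, ratio):
    base, plus = [], []
    for lora in loras:
        for name, param in lora.named_parameters():
            # Assumption: up-side ("B") parameters contain "lora_up" in their name;
            # DyLoRA modules may name them differently (e.g. lora_B).
            (plus if ratio is not None and "lora_up" in name else base).append(param)
    groups = []
    if base:
        groups.append({"params": base, "lr": lr})
    if plus:
        groups.append({"params": plus, "lr": lr * ratio})  # LoRA+ scales the up side
    return groups

# New calling convention after this commit (illustrative learning rates):
# ratios are attached once at network creation, and prepare_optimizer_params()
# now takes only the learning rates.
# network.set_loraplus_lr_ratio(16.0, None, 4.0)
# params = network.prepare_optimizer_params(text_encoder_lr=1e-5, unet_lr=1e-4, default_lr=1e-4)
# optimizer = torch.optim.AdamW(params)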

7 changes: 6 additions & 1 deletion networks/lora.py
@@ -499,7 +499,8 @@ def create_network(
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

if block_lr_weight is not None:
network.set_block_lr_weight(block_lr_weight)
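Side note (illustrative, not part of the commit): the added guard means set_loraplus_lr_ratio() is now called only when at least one ratio was actually supplied; otherwise the None defaults from __init__ stay in place and plain LoRA behaviour is preserved. A trivial illustration of the check with made-up values:

# Illustrative only: LoRA+ is wired up if any of the three ratios was given.
loraplus_lr_ratio = None
loraplus_unet_lr_ratio = 16.0            # hypothetical value from --network_args
loraplus_text_encoder_lr_ratio = None
enabled = any(r is not None for r in (loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio))
print(enabled)  # True -> network.set_loraplus_lr_ratio(...) would be called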
@@ -855,6 +856,10 @@ def __init__(
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout

self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None

if modules_dim is not None:
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
52 changes: 36 additions & 16 deletions networks/lora_fa.py
@@ -15,8 +15,10 @@
import torch
import re
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -504,6 +506,15 @@ def create_network(
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)

loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

return network


@@ -529,7 +540,9 @@ def parse_floats(s):
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks

if block_alphas is not None:
@@ -803,21 +816,31 @@ def __init__(
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout

self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None

if modules_dim is not None:
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
logger.info(f"create LoRA network from block_dims")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
logger.info(f"conv_block_dims: {conv_block_dims}")
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None:
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)

# create module instances
def create_modules(
@@ -939,6 +962,11 @@ def create_modules(
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)

def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
@@ -1033,15 +1061,7 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
return lr_weight

# It might be good to allow setting separate learning rates for the two Text Encoders
def prepare_optimizer_params(
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []

@@ -1078,7 +1098,7 @@ def assemble_params(loras, lr, ratio):
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)

@@ -1097,15 +1117,15 @@ def assemble_params(loras, lr, ratio):
params = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
unet_loraplus_ratio or loraplus_ratio
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)

else:
params = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
unet_loraplus_ratio or loraplus_ratio
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
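Side note (illustrative, not part of the commit): in the block-wise branch above, two scalings stack. get_lr_weight() scales the group's base learning rate, and the LoRA+ ratio is then applied on top of that for the up-side parameters inside assemble_params. A small worked example with made-up numbers:

# Worked example with illustrative numbers (assumes assemble_params applies the
# ratio to the up-side parameters, as sketched for networks/dylora.py above).
unet_lr = 1e-4
block_weight = 0.5             # e.g. from self.get_lr_weight(block_loras[0])
loraplus_ratio = 16.0          # self.loraplus_unet_lr_ratio or the network-wide ratio
base_lr = unet_lr * block_weight        # lr for lora_down params -> 5e-05
plus_lr = base_lr * loraplus_ratio      # lr for lora_up params   -> 8e-04
print(base_lr, plus_lr)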

