
Fix bad Lora parameters (#2341)
bmaltais authored Apr 19, 2024
1 parent b4f1589 commit 98d826c
Showing 1 changed file with 38 additions and 41 deletions.
79 changes: 38 additions & 41 deletions kohya_gui/lora_gui.py
@@ -69,32 +69,6 @@
 ]
 
 
-def update_network_args_with_kohya_lora_vars(
-    network_args: str, kohya_lora_var_list: list, vars: dict
-) -> str:
-    """
-    Update network arguments with Kohya LoRA variables.
-
-    Args:
-        network_args (str): The network arguments.
-        kohya_lora_var_list (list): The list of Kohya LoRA variables.
-        vars (dict): The dictionary of variables.
-
-    Returns:
-        str: The updated network arguments.
-    """
-    # Filter out variables that are in the Kohya LoRA variable list and have a value
-    kohya_lora_vars = {
-        key: value for key, value in vars if key in kohya_lora_var_list and value
-    }
-
-    # Iterate over the Kohya LoRA variables and append them to the network arguments
-    for key, value in kohya_lora_vars.items():
-        # Append each variable as a key-value pair to the network_args
-        network_args += f" {key}={value}"
-    return network_args
-
-
 def save_configuration(
     save_as_bool,
     file_path,
@@ -960,11 +934,17 @@ def train_model(
             "module_dropout",
         ]
         network_module = "networks.lora"
-        network_args += update_network_args_with_kohya_lora_vars(
-            network_args=network_args,
-            kohya_lora_var_list=kohya_lora_var_list,
-            vars=vars().items(),
-        )
+        kohya_lora_vars = {
+            key: value
+            for key, value in vars().items()
+            if key in kohya_lora_var_list and value
+        }
+        if LoRA_type == "Kohya LoCon":
+            network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"'
+
+        for key, value in kohya_lora_vars.items():
+            if value:
+                network_args += f' {key}={value}'
 
     if LoRA_type in ["LoRA-FA"]:
         kohya_lora_var_list = [
@@ -979,12 +959,21 @@
             "rank_dropout",
             "module_dropout",
         ]
+
         network_module = "networks.lora_fa"
-        network_args += update_network_args_with_kohya_lora_vars(
-            network_args=network_args,
-            kohya_lora_var_list=kohya_lora_var_list,
-            vars=vars().items(),
-        )
+        kohya_lora_vars = {
+            key: value
+            for key, value in vars().items()
+            if key in kohya_lora_var_list and value
+        }
+
+        network_args = ""
+        if LoRA_type == "Kohya LoCon":
+            network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"'
+
+        for key, value in kohya_lora_vars.items():
+            if value:
+                network_args += f' {key}={value}'
 
     if LoRA_type in ["Kohya DyLoRA"]:
         kohya_lora_var_list = [
@@ -1002,12 +991,20 @@
             "module_dropout",
             "unit",
         ]
+
         network_module = "networks.dylora"
-        network_args += update_network_args_with_kohya_lora_vars(
-            network_args=network_args,
-            kohya_lora_var_list=kohya_lora_var_list,
-            vars=vars().items(),
-        )
+        kohya_lora_vars = {
+            key: value
+            for key, value in vars().items()
+            if key in kohya_lora_var_list and value
+        }
+
+        network_args = ""
+
+        for key, value in kohya_lora_vars.items():
+            if value:
+                network_args += f' {key}={value}'
 
     # Convert learning rates to float once and store the result for re-use
     learning_rate = float(learning_rate) if learning_rate is not None else 0.0
     text_encoder_lr_float = (
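The mechanics behind the removed helper are worth spelling out: update_network_args_with_kohya_lora_vars() returned the original network_args string plus the new pairs, and every call site then did network_args += on that result, so everything already accumulated in network_args was duplicated — which appears to be the "bad Lora parameters" the commit title refers to. The standalone sketch below (simplified, with hypothetical sample values such as "preset=full"; the real code filters a vars() snapshot inside train_model()) reproduces the old call pattern next to the inlined replacement:

def update_network_args_with_kohya_lora_vars(network_args, kohya_lora_var_list, vars):
    # Old helper, as removed by the commit: filter, append, and return the
    # ORIGINAL string plus the new key=value pairs.
    kohya_lora_vars = {
        key: value for key, value in vars if key in kohya_lora_var_list and value
    }
    for key, value in kohya_lora_vars.items():
        network_args += f" {key}={value}"
    return network_args

allowed = ["conv_dim", "conv_alpha"]
snapshot = {"conv_dim": 8, "conv_alpha": 1, "unrelated": 42}

# Old call pattern: += on a helper that already returns network_args,
# so the existing contents appear twice.
network_args = "preset=full"
network_args += update_network_args_with_kohya_lora_vars(
    network_args=network_args,
    kohya_lora_var_list=allowed,
    vars=snapshot.items(),
)
print(network_args)  # -> "preset=fullpreset=full conv_dim=8 conv_alpha=1"

# New inlined pattern: append only the filtered, truthy pairs, once.
network_args = "preset=full"
kohya_lora_vars = {
    key: value
    for key, value in snapshot.items()
    if key in allowed and value
}
for key, value in kohya_lora_vars.items():
    network_args += f" {key}={value}"
print(network_args)  # -> "preset=full conv_dim=8 conv_alpha=1"

Note also that, as committed, the LoRA-FA and Kohya DyLoRA branches reset network_args to an empty string before appending, so anything accumulated in that variable earlier in train_model() is discarded in those branches.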
