Skip to content

Commit

Permalink
Fix unset or invalid LR from making a param_group
Browse files Browse the repository at this point in the history
  • Loading branch information
rockerBOO committed Apr 11, 2024
1 parent 75833e8 commit 68467bd
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions networks/dylora.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def prepare_optimizer_params(
text_encoder_lr,
unet_lr,
default_lr,
unet_loraplus_ratio=None,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
Expand Down Expand Up @@ -441,7 +441,7 @@ def assemble_params(loras, lr, ratio):
else:
param_data["lr"] = lr

if ("lr" in param_data) and (param_data["lr"] == 0):
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue

params.append(param_data)
Expand Down
5 changes: 3 additions & 2 deletions networks/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,8 +1040,8 @@ def prepare_optimizer_params(
text_encoder_lr,
unet_lr,
default_lr,
unet_loraplus_ratio=None,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
Expand Down Expand Up @@ -1069,7 +1069,8 @@ def assemble_params(loras, lr, ratio):
else:
param_data["lr"] = lr

if ("lr" in param_data) and (param_data["lr"] == 0):
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
print("NO LR skipping!")
continue

params.append(param_data)
Expand Down
4 changes: 2 additions & 2 deletions networks/lora_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,8 +1038,8 @@ def prepare_optimizer_params(
text_encoder_lr,
unet_lr,
default_lr,
unet_loraplus_ratio=None,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
Expand Down Expand Up @@ -1067,7 +1067,7 @@ def assemble_params(loras, lr, ratio):
else:
param_data["lr"] = lr

if ("lr" in param_data) and (param_data["lr"] == 0):
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue

params.append(param_data)
Expand Down

0 comments on commit 68467bd

Please sign in to comment.