Merge pull request #948 from bghira/main
merge
bghira authored Sep 6, 2024
2 parents 38f84f8 + d1e25e0 commit ac17183
Showing 4 changed files with 48 additions and 1 deletion.
7 changes: 7 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -319,6 +319,13 @@ def parse_cmdline_args(input_args=None):
default="config/lycoris_config.json",
help=("The location for the JSON file of the Lycoris configuration."),
)
parser.add_argument(
"--init_lokr_norm",
type=float,
required=False,
default=None,
help=("Setting this turns on perturbed normal initialization of the LyCORIS LoKr PEFT layers. A good value is between 1e-4 and 1e-2."),
)
parser.add_argument(
"--controlnet",
action="store_true",
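For reference, a standalone argparse sketch (not the project's actual `parse_cmdline_args`) showing how the new flag maps onto the attribute the trainer reads; the suggested value range comes from the help text above, everything else here is illustrative:

```python
import argparse

# Minimal stand-in for the project's parser, only to show the flag's behaviour.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--init_lokr_norm",
    type=float,
    required=False,
    default=None,
    help="Perturbed normal initialization scale for LyCORIS LoKr layers (1e-4 to 1e-2 suggested).",
)

args = parser.parse_args(["--init_lokr_norm", "1e-3"])
assert args.init_lokr_norm == 1e-3

args = parser.parse_args([])
assert args.init_lokr_norm is None  # feature stays off unless the flag is passed
```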
24 changes: 24 additions & 0 deletions helpers/training/peft_init.py
@@ -0,0 +1,24 @@
import torch


def approximate_normal_tensor(inp, target, scale=1.0):
tensor = torch.randn_like(target)
desired_norm = inp.norm()
desired_mean = inp.mean()
desired_std = inp.std()

current_norm = tensor.norm()
tensor = tensor * (desired_norm / current_norm)
current_std = tensor.std()
tensor = tensor * (desired_std / current_std)
tensor = tensor - tensor.mean() + desired_mean
tensor.mul_(scale)

target.copy_(tensor)


def init_lokr_network_with_perturbed_normal(lycoris, scale=1e-3):
with torch.no_grad():
for lora in lycoris.loras:
lora.lokr_w1.fill_(1.0)
approximate_normal_tensor(lora.org_weight, lora.lokr_w2, scale=scale)
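A quick illustration of what the new helper does, assuming `helpers/training/peft_init.py` is importable from the repository root: `approximate_normal_tensor` fills its `target` in place with Gaussian noise rescaled toward the source tensor's mean and standard deviation, then shrinks everything by `scale`. The tensor names below are stand-ins, not real LoKr modules:

```python
import torch

from helpers.training.peft_init import approximate_normal_tensor

org_weight = torch.randn(64, 64) * 0.02 + 0.1  # stand-in for a frozen base weight
lokr_w2 = torch.empty(64, 64)                  # stand-in for a LoKr factor to initialize

approximate_normal_tensor(org_weight, lokr_w2, scale=1e-3)

# After the call, lokr_w2 holds noise whose (scaled-down) statistics track org_weight's.
print("source mean/std:", org_weight.mean().item(), org_weight.std().item())
print("target mean/std over scale:", lokr_w2.mean().item() / 1e-3, lokr_w2.std().item() / 1e-3)
```

With `lokr_w1` filled with ones, the LoKr product then starts as a small perturbation shaped like the frozen weight rather than an exact zero, which is what the new `--init_lokr_norm` option enables.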
8 changes: 8 additions & 0 deletions helpers/training/trainer.py
@@ -50,6 +50,7 @@
segmented_timestep_selection,
)
from helpers.training.min_snr_gamma import compute_snr
from helpers.training.peft_init import init_lokr_network_with_perturbed_normal
from accelerate.logging import get_logger
from diffusers.models.embeddings import get_2d_rotary_pos_embed
from helpers.models.smoldit import get_resize_crop_region_for_grid
@@ -794,6 +795,12 @@ def init_trainable_peft_adapter(self):
**self.lycoris_config,
)

if self.config.init_lokr_norm is not None:
init_lokr_network_with_perturbed_normal(
self.lycoris_wrapped_network,
scale=self.config.init_lokr_norm,
)

self.lycoris_wrapped_network.apply_to()
setattr(
self.accelerator,
@@ -1319,6 +1326,7 @@ def init_trackers(self):
delattr(public_args, "process_group_kwargs")
delattr(public_args, "weight_dtype")
delattr(public_args, "base_weight_dtype")
delattr(public_args, "vae_kwargs")

# Hash the contents of public_args to reflect a deterministic ID for a single set of params:
public_args_hash = hashlib.md5(
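The added `delattr(public_args, "vae_kwargs")` strips one more attribute before the namespace is hashed into a deterministic run ID. A minimal sketch of that pattern, with illustrative names and an assumed `json.dumps` serialization (the trainer's own serialization lies outside this hunk):

```python
import hashlib
import json
from types import SimpleNamespace

# Illustrative stand-in for the public_args namespace built in init_trackers().
public_args = SimpleNamespace(init_lokr_norm=1e-3, vae_kwargs={"force_upcast": False})
delattr(public_args, "vae_kwargs")  # drop attributes that should not affect the run ID

public_args_hash = hashlib.md5(
    json.dumps(vars(public_args), sort_keys=True).encode("utf-8")
).hexdigest()
print(public_args_hash)  # stable for a given set of remaining parameters
```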
10 changes: 9 additions & 1 deletion helpers/training/validation.py
@@ -1333,7 +1333,15 @@ def _log_validations_to_webhook(

def _log_validations_to_trackers(self, validation_images):
for tracker in self.accelerator.trackers:
if tracker.name == "wandb":
if tracker.name == "comet_ml":
experiment = self.accelerator.get_tracker("comet_ml").tracker
for shortname, image_list in validation_images.items():
for idx, image in enumerate(image_list):
experiment.log_image(
image,
name=f"{shortname} - {self.validation_resolutions[idx]}",
)
elif tracker.name == "wandb":
resolution_list = [
f"{res[0]}x{res[1]}" for res in get_validation_resolutions()
]
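A hedged sketch of the structure the new Comet ML branch iterates over: `validation_images` is assumed to map a prompt shortname to a list of images ordered like `self.validation_resolutions`. The names and resolution formatting below are illustrative, and the real `experiment.log_image` call is left as a comment so the snippet runs without a Comet API key:

```python
from PIL import Image

# Illustrative stand-in for the mapping consumed by _log_validations_to_trackers:
# image_list[idx] lines up with validation_resolutions[idx].
validation_resolutions = ["1024x1024", "1152x960"]  # assumed string form
validation_images = {
    "portrait-prompt": [Image.new("RGB", (1024, 1024)), Image.new("RGB", (1152, 960))],
}

for shortname, image_list in validation_images.items():
    for idx, image in enumerate(image_list):
        # With a live comet_ml Experiment this would be:
        # experiment.log_image(image, name=f"{shortname} - {validation_resolutions[idx]}")
        print(f"{shortname} - {validation_resolutions[idx]}: {image.size}")
```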
