Skip to content

Commit

Permalink
Min-SNR Weighting Strategy: Refactored and added to all trainers
Browse files Browse the repository at this point in the history
  • Loading branch information
AI-Casanova committed Mar 22, 2023
1 parent 795a6bd commit 64c9232
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 14 deletions.
8 changes: 7 additions & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
ConfigSanitizer,
BlueprintGenerator,
)

import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight

def collate_fn(examples):
return examples[0]
Expand Down Expand Up @@ -304,6 +305,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

if args.min_snr_gamma:
loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
Expand Down Expand Up @@ -396,6 +400,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)


parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
Expand Down
17 changes: 17 additions & 0 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch
import argparse

def apply_snr_weight(loss, latents, noisy_latents, gamma):
sigma = torch.sub(noisy_latents, latents) #find noise as applied by scheduler
zeros = torch.zeros_like(sigma)
alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment
sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment
snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
loss = loss * snr_weight
print(snr_weight)
return loss

def add_custom_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.")
2 changes: 1 addition & 1 deletion library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,7 +1963,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
)
parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.")


def verify_training_args(args: argparse.Namespace):
if args.v_parameterization and not args.v2:
Expand Down
7 changes: 6 additions & 1 deletion train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
ConfigSanitizer,
BlueprintGenerator,
)

import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight

def collate_fn(examples):
return examples[0]
Expand Down Expand Up @@ -291,6 +292,9 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

accelerator.backward(loss)
Expand Down Expand Up @@ -390,6 +394,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)

parser.add_argument(
"--no_token_padding",
Expand Down
17 changes: 6 additions & 11 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
ConfigSanitizer,
BlueprintGenerator,
)

import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight

def collate_fn(examples):
return examples[0]
Expand Down Expand Up @@ -548,16 +549,9 @@ def train(args):

loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
gamma = args.min_snr_gamma
if gamma:
sigma = torch.sub(noisy_latents, latents) #find noise as applied
zeros = torch.zeros_like(sigma)
alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square
sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square
snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
loss = loss * snr_weight

if args.min_snr_gamma:
loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

Expand Down Expand Up @@ -662,6 +656,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)

parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument(
Expand Down
6 changes: 6 additions & 0 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight

imagenet_templates_small = [
"a photo of a {}",
Expand Down Expand Up @@ -377,6 +379,9 @@ def train(args):

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)

loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
Expand Down Expand Up @@ -534,6 +539,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)

parser.add_argument(
"--save_model_as",
Expand Down

0 comments on commit 64c9232

Please sign in to comment.