diff --git a/fine_tune.py b/fine_tune.py index 52e84c43f..b07876776 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -355,7 +355,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 28b625d30..e0a026dae 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -57,10 +57,13 @@ def enforce_zero_terminal_snr(betas): noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) - gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper + min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) + if v_prediction: + snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device) + else: + snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) loss = loss * snr_weight return loss diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 54abf697c..44447d1f0 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -460,7 +460,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index f00f10eaa..91cbacc6a 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -430,7 +430,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/train_controlnet.py b/train_controlnet.py index bbd915cb3..e0118d1c5 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -449,7 +449,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_db.py b/train_db.py index 7fbbc18ac..966999dfb 100644 --- a/train_db.py +++ b/train_db.py @@ -342,7 +342,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: diff --git a/train_network.py b/train_network.py index d50916b74..1cbed2e7b 100644 --- a/train_network.py +++ b/train_network.py @@ -812,7 +812,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 6b6e7f5a0..45a437b91 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -578,7 +578,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 8dd5c672f..f77ad2eb2 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -469,7 +469,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: