Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption #1228

Merged: 14 commits, Apr 7, 2024
6 changes: 3 additions & 3 deletions fine_tune.py
@@ -354,7 +354,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# Predict the noise residual
with accelerator.autocast():
@@ -368,7 +368,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
@@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
79 changes: 75 additions & 4 deletions library/train_util.py
@@ -3236,6 +3236,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
)
+ parser.add_argument(
+     "--loss_type",
+     type=str,
+     default="l2",
+     choices=["l2", "huber", "smooth_l1"],
+     help="The type of loss function to use: l2 (plain MSE), huber, or smooth_l1",
+ )
+ parser.add_argument(
+     "--huber_schedule",
+     type=str,
+     default="exponential",
+     choices=["constant", "exponential", "snr"],
+     help="How the Huber loss parameter is scheduled over timesteps: constant, exponential decay, or SNR-based",
+ )
+ parser.add_argument(
+     "--huber_c",
+     type=float,
+     default=0.1,
+     help="The Huber loss parameter. Only used if one of the Huber loss modes (huber or smooth_l1) is selected with loss_type.",
+ )

parser.add_argument(
"--lowram",
@@ -4842,6 +4862,38 @@ def save_sd_model_on_train_end_common(
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)

+ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
+     # TODO: when a Huber loss is selected, a single timestep is sampled and shared
+     # by the whole batch; a smarter per-sample schedule may be possible in the future
+     if args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
+         timesteps = torch.randint(min_timestep, max_timestep, (1,), device='cpu')
+         timestep = timesteps.item()
+ 
+         if args.huber_schedule == "exponential":
+             alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+             huber_c = math.exp(-alpha * timestep)
+         elif args.huber_schedule == "snr":
+             alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+             sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+             huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+         elif args.huber_schedule == "constant":
+             huber_c = args.huber_c
+         else:
+             raise NotImplementedError(f'Unknown Huber loss schedule {args.huber_schedule}!')
+ 
+         timesteps = timesteps.repeat(b_size).to(device)
+     elif args.loss_type == 'l2':
+         timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
+         huber_c = 1  # may be anything, as it is not used for l2
+     else:
+         raise NotImplementedError(f'Unknown loss type {args.loss_type}')
+     timesteps = timesteps.long()
+ 
+     return timesteps, huber_c
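
For a typical noise schedule, both non-constant modes move huber_c from roughly 1 at t = 0 (little noise, so the loss stays close to plain L2) down to the configured --huber_c at the final timestep (heavy noise, so the loss tolerates outliers). A minimal sketch of the exponential decay above, assuming the default 1000 training timesteps and the default --huber_c of 0.1:

import math

# Hypothetical values standing in for noise_scheduler.config and args
num_train_timesteps = 1000
huber_c_arg = 0.1

alpha = -math.log(huber_c_arg) / num_train_timesteps
for t in (0, 250, 500, 750, 999):
    print(t, round(math.exp(-alpha * t), 3))
# prints 1.0, 0.562, 0.316, 0.178, 0.1: a smooth decay from 1.0 down to huber_c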

def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample noise that we'll add to the latents
@@ -4862,8 +4914,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

- timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
- timesteps = timesteps.long()
+ timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -4876,8 +4927,28 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

- return noise, noisy_latents, timesteps
+ return noise, noisy_latents, timesteps, huber_c

+ # NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+ def conditional_loss(model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1):
+     if loss_type == 'l2':
+         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
+     elif loss_type == 'huber':
+         loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+         if reduction == "mean":
+             loss = torch.mean(loss)
+         elif reduction == "sum":
+             loss = torch.sum(loss)
+     elif loss_type == 'smooth_l1':
+         loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+         if reduction == "mean":
+             loss = torch.mean(loss)
+         elif reduction == "sum":
+             loss = torch.sum(loss)
+     else:
+         raise NotImplementedError(f'Unsupported Loss Type {loss_type}')
+     return loss
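
The 'huber' branch above is the pseudo-Huber loss 2c * (sqrt(r^2 + c^2) - c) for residual r = model_pred - target: it tracks plain L2 (r^2) when |r| is well below c, but grows only linearly (about 2c * |r|) once |r| dwarfs c, which is what blunts the gradient from corrupted samples; 'smooth_l1' is the same curve divided by c. A quick numerical check of that behavior, offered as a sketch of my own rather than part of the change:

import torch

pred = torch.tensor([0.01, 5.0])   # one small residual, one outlier
target = torch.zeros(2)
c = 0.1

l2 = (pred - target) ** 2
pseudo_huber = 2 * c * (torch.sqrt((pred - target) ** 2 + c**2) - c)

print(l2)            # tensor([1.0000e-04, 2.5000e+01])
print(pseudo_huber)  # ~1.0e-04 (matches L2) and ~0.98 (~2*c*|r|, far below 25)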

def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
names = []
6 changes: 3 additions & 3 deletions sdxl_train.py
@@ -582,7 +582,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

@@ -600,7 +600,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -616,7 +616,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite.py
@@ -439,7 +439,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

@@ -458,7 +458,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite_old.py
@@ -406,7 +406,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

@@ -426,7 +426,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
11 changes: 3 additions & 8 deletions train_controlnet.py
@@ -420,13 +420,8 @@ def remove_model(old_ckpt_name):
)

# Sample a random timestep for each image
- timesteps = torch.randint(
-     0,
-     noise_scheduler.config.num_train_timesteps,
-     (b_size,),
-     device=latents.device,
- )
- timesteps = timesteps.long()
+ timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -457,7 +452,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
4 changes: 2 additions & 2 deletions train_db.py
@@ -346,7 +346,7 @@ def train(args):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# Predict the noise residual
with accelerator.autocast():
@@ -358,7 +358,7 @@ def train(args):
else:
target = noise

- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
4 changes: 2 additions & 2 deletions train_network.py
@@ -843,7 +843,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -873,7 +873,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
4 changes: 2 additions & 2 deletions train_textual_inversion.py
@@ -572,7 +572,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -588,7 +588,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
4 changes: 2 additions & 2 deletions train_textual_inversion_XTI.py
@@ -461,7 +461,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# Predict the noise residual
with accelerator.autocast():
@@ -473,7 +473,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])