From 45b64ecf41398efdcc650d7bc29ca4e97de4430f Mon Sep 17 00:00:00 2001 From: Raphael Walker Date: Thu, 13 Jun 2024 12:12:33 +0200 Subject: [PATCH 01/12] Add lognorm and cosmap weighting --- examples/dreambooth/train_dreambooth_lora_sd3.py | 15 +++++++++++---- examples/dreambooth/train_dreambooth_sd3.py | 16 ++++++++++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 5e8bc7bab818..1c6c55278699 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1487,12 +1487,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): weighting = (sigmas**-2.0).float() elif args.weighting_scheme == "logit_normal": # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) - weighting = torch.nn.functional.sigmoid(u) + # A better approach is just to sample the timestamps non-uniformly. + m = args.logit_mean + s = args.logit_std + weighting = torch.exp(-(torch.logit(sigmas) - m)**2 / (2 * s**2)) + weighting = weighting / (sigmas * (1 - sigmas) * s * math.sqrt(2 * math.pi)) elif args.weighting_scheme == "mode": # See sec 3.1 in the SD3 paper (20). - u = torch.rand(size=(bsz,), device=accelerator.device) - weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + raise NotImplementedError("Mode weighting scheme is not implemented.") + elif args.weighting_scheme == "cosmap": + bot = (1 - 2*sigmas + 2*sigmas**2) + weighting = 2/(math.pi*bot) + else: + weighting = torch.ones_like(sigmas) # simplified flow matching aka 0-rectified flow matching loss # target = model_input - noise diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index adcea652db74..fc60d2ac8e6d 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1566,12 +1566,20 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): weighting = (sigmas**-2.0).float() elif args.weighting_scheme == "logit_normal": # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) - weighting = torch.nn.functional.sigmoid(u) + # A better approach is just to sample the timestamps non-uniformly. + m = args.logit_mean + s = args.logit_std + weighting = torch.exp(-(torch.logit(sigmas) - m)**2 / (2 * s**2)) + weighting = weighting / (sigmas * (1 - sigmas) * s * math.sqrt(2 * math.pi)) elif args.weighting_scheme == "mode": # See sec 3.1 in the SD3 paper (20). - u = torch.rand(size=(bsz,), device=accelerator.device) - weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + raise NotImplementedError("Mode weighting scheme is not implemented.") + + elif args.weighting_scheme == "cosmap": + bot = (1 - 2*sigmas + 2*sigmas**2) + weighting = 2/(math.pi*bot) + else: + weighting = torch.ones_like(sigmas) # simplified flow matching aka 0-rectified flow matching loss # target = model_input - noise From 5c3f7552f4d26127c1420f315a32518803fecc0c Mon Sep 17 00:00:00 2001 From: Raphael Walker Date: Thu, 13 Jun 2024 13:20:30 +0200 Subject: [PATCH 02/12] Implement mode sampling --- .../dreambooth/train_dreambooth_lora_sd3.py | 26 +++++++++------- examples/dreambooth/train_dreambooth_sd3.py | 30 ++++++++++--------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 1c6c55278699..3fc549f55c82 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1462,7 +1462,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): bsz = model_input.shape[0] # Sample a random timestep for each image - indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,)) + # for weighting schemes where we sample timesteps non-uniformly + if args.weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) + u = torch.nn.functional.sigmoid(u) + elif args.weighting_scheme == "mode": + u = torch.rand(size=(bsz,), device=accelerator.device) + u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(bsz,), device=accelerator.device) + + + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. @@ -1483,18 +1495,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred = model_pred * (-sigmas) + noisy_model_input # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :) + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss if args.weighting_scheme == "sigma_sqrt": weighting = (sigmas**-2.0).float() - elif args.weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - # A better approach is just to sample the timestamps non-uniformly. - m = args.logit_mean - s = args.logit_std - weighting = torch.exp(-(torch.logit(sigmas) - m)**2 / (2 * s**2)) - weighting = weighting / (sigmas * (1 - sigmas) * s * math.sqrt(2 * math.pi)) - elif args.weighting_scheme == "mode": - # See sec 3.1 in the SD3 paper (20). - raise NotImplementedError("Mode weighting scheme is not implemented.") elif args.weighting_scheme == "cosmap": bot = (1 - 2*sigmas + 2*sigmas**2) weighting = 2/(math.pi*bot) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index fc60d2ac8e6d..e94c3c93bb92 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1526,7 +1526,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): bsz = model_input.shape[0] # Sample a random timestep for each image - indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,)) + # for weighting schemes where we sample timesteps non-uniformly + if args.weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) + u = torch.nn.functional.sigmoid(u) + elif args.weighting_scheme == "mode": + u = torch.rand(size=(bsz,), device=accelerator.device) + u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(bsz,), device=accelerator.device) + + + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. @@ -1560,21 +1572,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. model_pred = model_pred * (-sigmas) + noisy_model_input - - # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss if args.weighting_scheme == "sigma_sqrt": weighting = (sigmas**-2.0).float() - elif args.weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - # A better approach is just to sample the timestamps non-uniformly. - m = args.logit_mean - s = args.logit_std - weighting = torch.exp(-(torch.logit(sigmas) - m)**2 / (2 * s**2)) - weighting = weighting / (sigmas * (1 - sigmas) * s * math.sqrt(2 * math.pi)) - elif args.weighting_scheme == "mode": - # See sec 3.1 in the SD3 paper (20). - raise NotImplementedError("Mode weighting scheme is not implemented.") - elif args.weighting_scheme == "cosmap": bot = (1 - 2*sigmas + 2*sigmas**2) weighting = 2/(math.pi*bot) From 77305e39813196b3098d20dd0fbf6f074c66ac9a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 13 Jun 2024 18:30:10 +0200 Subject: [PATCH 03/12] Update examples/dreambooth/train_dreambooth_lora_sd3.py --- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 3fc549f55c82..198ad0123d8e 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1500,7 +1500,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.weighting_scheme == "sigma_sqrt": weighting = (sigmas**-2.0).float() elif args.weighting_scheme == "cosmap": - bot = (1 - 2*sigmas + 2*sigmas**2) + bot = 1 - 2 * sigmas + 2 * sigmas**2 weighting = 2/(math.pi*bot) else: weighting = torch.ones_like(sigmas) From 41803fd6854efa45f5380f691a4ed73fcbb2b179 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 13 Jun 2024 18:30:16 +0200 Subject: [PATCH 04/12] Update examples/dreambooth/train_dreambooth_lora_sd3.py --- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 198ad0123d8e..23ad4a9c1e30 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1501,7 +1501,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): weighting = (sigmas**-2.0).float() elif args.weighting_scheme == "cosmap": bot = 1 - 2 * sigmas + 2 * sigmas**2 - weighting = 2/(math.pi*bot) + weighting = 2 / (math.pi * bot) else: weighting = torch.ones_like(sigmas) From 3a428befb21fdf2d7654c96cfeba4515b2464320 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 13 Jun 2024 18:30:24 +0200 Subject: [PATCH 05/12] Update examples/dreambooth/train_dreambooth_sd3.py --- examples/dreambooth/train_dreambooth_sd3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index e94c3c93bb92..763e0ee30110 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1572,7 +1572,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. model_pred = model_pred * (-sigmas) + noisy_model_input - # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss if args.weighting_scheme == "sigma_sqrt": From 94735893564696946548b867cc2f22586ce7907c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 13 Jun 2024 18:30:30 +0200 Subject: [PATCH 06/12] Update examples/dreambooth/train_dreambooth_sd3.py --- examples/dreambooth/train_dreambooth_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 763e0ee30110..2d6224b7dcbd 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1577,7 +1577,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.weighting_scheme == "sigma_sqrt": weighting = (sigmas**-2.0).float() elif args.weighting_scheme == "cosmap": - bot = (1 - 2*sigmas + 2*sigmas**2) + bot = 1 - 2 * sigmas + 2 * sigmas**2 weighting = 2/(math.pi*bot) else: weighting = torch.ones_like(sigmas) From 6e231393e94f302c723aaf0f38719d715c84cef2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 13 Jun 2024 18:30:35 +0200 Subject: [PATCH 07/12] Update examples/dreambooth/train_dreambooth_sd3.py --- examples/dreambooth/train_dreambooth_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 2d6224b7dcbd..d3c0aa45a7c7 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1578,7 +1578,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): weighting = (sigmas**-2.0).float() elif args.weighting_scheme == "cosmap": bot = 1 - 2 * sigmas + 2 * sigmas**2 - weighting = 2/(math.pi*bot) + weighting = 2 / (math.pi * bot) else: weighting = torch.ones_like(sigmas) From 994da3d9da85cfd1adebbee232619c656a6524bf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 13 Jun 2024 18:42:21 +0200 Subject: [PATCH 08/12] Update examples/dreambooth/train_dreambooth_lora_sd3.py --- examples/dreambooth/train_dreambooth_lora_sd3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 23ad4a9c1e30..293eccafbf21 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1473,7 +1473,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: u = torch.rand(size=(bsz,), device=accelerator.device) - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) From f360c2d8a533579cc69983648998b7458220a435 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 13 Jun 2024 18:48:58 +0200 Subject: [PATCH 09/12] Update examples/dreambooth/train_dreambooth_sd3.py --- examples/dreambooth/train_dreambooth_sd3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index d3c0aa45a7c7..c4298d56ad34 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1537,7 +1537,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: u = torch.rand(size=(bsz,), device=accelerator.device) - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) From 0c09c91313666aec44750831d865d5ac8927ed95 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 13 Jun 2024 21:16:58 +0200 Subject: [PATCH 10/12] Update examples/dreambooth/train_dreambooth_sd3.py --- examples/dreambooth/train_dreambooth_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 6bc76eac50b8..b287d7e489dc 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1537,7 +1537,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: u = torch.rand(size=(bsz,), device=accelerator.device) - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long().cpu() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. From 9feda1c59ac6219f2c8fda735f829aa8f042d2c0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 13 Jun 2024 21:17:06 +0200 Subject: [PATCH 11/12] Update examples/dreambooth/train_dreambooth_lora_sd3.py --- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 2af71a493431..b059df89af23 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1473,7 +1473,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: u = torch.rand(size=(bsz,), device=accelerator.device) - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long().cpu() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. From 731872984c7f23fca5da8c99ac64b8782c307d84 Mon Sep 17 00:00:00 2001 From: Raphael Walker Date: Fri, 14 Jun 2024 11:55:55 +0200 Subject: [PATCH 12/12] keep timestamp sampling fully on cpu --- examples/dreambooth/train_dreambooth_lora_sd3.py | 8 ++++---- examples/dreambooth/train_dreambooth_sd3.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index b059df89af23..67227e2defc6 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1465,15 +1465,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # for weighting schemes where we sample timesteps non-uniformly if args.weighting_scheme == "logit_normal": # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) + u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu") u = torch.nn.functional.sigmoid(u) elif args.weighting_scheme == "mode": - u = torch.rand(size=(bsz,), device=accelerator.device) + u = torch.rand(size=(bsz,), device="cpu") u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) else: - u = torch.rand(size=(bsz,), device=accelerator.device) + u = torch.rand(size=(bsz,), device="cpu") - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long().cpu() + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index b287d7e489dc..7920b4c8e0fa 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1529,15 +1529,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # for weighting schemes where we sample timesteps non-uniformly if args.weighting_scheme == "logit_normal": # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) + u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu") u = torch.nn.functional.sigmoid(u) elif args.weighting_scheme == "mode": - u = torch.rand(size=(bsz,), device=accelerator.device) + u = torch.rand(size=(bsz,), device="cpu") u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) else: - u = torch.rand(size=(bsz,), device=accelerator.device) + u = torch.rand(size=(bsz,), device="cpu") - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long().cpu() + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching.