feat(ml): add lambda #661

Open · wants to merge 7 commits into master
12 changes: 12 additions & 0 deletions docs/source/training.rst
@@ -196,3 +196,15 @@ Trains a consistency model to insert glasses onto faces.
 .. code:: bash
 
     python3 train.py --dataroot /path/to/data/noglasses2glasses_ffhq --checkpoints_dir /path/to/checkpoints --name noglasses2glasses --config_json examples/example_cm_noglasses2glasses.json
+
+********************************************************************************
+Adversarial Consistency Model training for object insertion / inpainting
+********************************************************************************
+
+Dataset: https://joligen.com/datasets/noglasses2glasses_ffhq.zip
+
+Trains an adversarial consistency model to insert glasses onto faces, achieving faster convergence than plain consistency model training.
+
+.. code:: bash
+
+    python3 train.py --dataroot /path/to/data/noglasses2glasses_ffhq --checkpoints_dir /path/to/checkpoints --name noglasses2glasses --config_json examples/example_cm_noglasses2glasses.json
31 changes: 21 additions & 10 deletions models/cm_gan_model.py
@@ -24,7 +24,8 @@ def __init__(self, opt, rank):
         visual_names_B = ["real_B"]
         self.visual_names.append(visual_names_A)
         self.visual_names.append(visual_names_B)
-
+        self.lambda1 = 0.6
+        self.lambda2 = 1.6
         if self.isTrain:
             # Discriminator(s)
             self.netDs = gan_networks.define_D(**vars(opt))
@@ -73,6 +74,7 @@ def __init__(self, opt, rank):
             )
             self.networks_groups.append(self.group_D)
             self.set_discriminators_info()
+            losses_GAN_lambda = ["GAN_lambda_function"]
             losses_D = []
             losses_G = ["G_cm"]
             for discriminator in self.discriminators:
@@ -81,10 +83,9 @@
                     continue
                 else:
                     losses_G.append(discriminator.loss_name_G)
-
             self.loss_names_D += losses_D
             self.loss_names_G += losses_G
-            self.loss_names = self.loss_names_G + self.loss_names_D
+            self.loss_names = self.loss_names_G + self.loss_names_D + losses_GAN_lambda
 
             # Itercalculator
             self.iter_calculator_init()
@@ -94,13 +95,23 @@ def compute_G_loss(self):
             with torch.cuda.amp.autocast(enabled=self.with_amp):
                 getattr(self, loss_function)()
 
-    def compute_cm_gan_loss(self): ##TODO: replace compute_cm_loss in backward
+    def lambda_function(self, n, N):
+        return self.lambda1 * (n / (N - 1)) ** self.lambda2
+
+    def compute_cm_gan_loss(self):
         self.compute_cm_loss()
-        self.loss_G_cm = self.loss_G_tot.clone().detach()
-        # print("loss_G_tot cm: ", self.loss_G_tot)
-        # print("self.loss_G_cm_tot: ", self.loss_G_cm_tot)
+        self.loss_G_cm = self.loss_G_tot
         self.fake_B = self.pred_x
         self.compute_G_loss()
-        # print("self.loss_G_cm_gan_tot: ", self.loss_G_cm_gan_tot)
-        # return self.loss_G_cm_tot + self.loss_G_cm_gan_tot
-        # print("loss_G_tot: ", self.loss_G_tot)
+        self.loss_G_cm_gan_tot = self.loss_G_tot
+        lambda_gan = self.lambda_function(
+            self.opt.total_iters, self.opt.alg_cm_num_steps
+        )
+        self.loss_GAN_lambda_function = lambda_gan
+        self.compute_D_loss()
+        loss_cm_gan_tot = (
+            self.loss_G_cm * (1 - lambda_gan)
+            + (self.loss_G_cm_gan_tot - self.loss_G_cm + self.loss_D_tot) * lambda_gan
+        )
+
+        return loss_cm_gan_tot
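
For readers skimming the diff, here is a minimal standalone sketch of what the new schedule does and how the losses are blended: lambda(n) = lambda1 * (n / (N - 1)) ** lambda2 ramps the adversarial weight from 0 up to lambda1 over training, and the returned loss mixes the consistency loss with the adversarial terms. The constants 0.6 and 1.6 are copied from `self.lambda1` / `self.lambda2` above; treating `n` as `opt.total_iters` and `N` as `opt.alg_cm_num_steps` mirrors the call site, and `loss_gan_g` below stands in for the isolated adversarial generator term (`loss_G_cm_gan_tot - loss_G_cm` in the diff). The iteration counts and loss values in the demo loop are made up for illustration.

```python
# Standalone sketch (not part of the diff) of the GAN weighting schedule added
# in this PR. lambda1/lambda2 match cm_gan_model.py; everything else is a
# hypothetical stand-in used only to show the shape of the schedule.

def lambda_schedule(n: int, N: int, lambda1: float = 0.6, lambda2: float = 1.6) -> float:
    """Weight of the adversarial term at training iteration n out of N steps.

    Ramps from 0 at n == 0 up to lambda1 at n == N - 1, with curvature lambda2.
    """
    return lambda1 * (n / (N - 1)) ** lambda2


def blended_loss(loss_cm: float, loss_gan_g: float, loss_d: float, lam: float) -> float:
    """Mirror of loss_cm_gan_tot: (1 - lam) * CM loss + lam * (GAN G loss + D loss)."""
    return loss_cm * (1.0 - lam) + (loss_gan_g + loss_d) * lam


if __name__ == "__main__":
    N = 100_000  # stand-in for opt.alg_cm_num_steps
    for n in (0, 10_000, 50_000, 99_999):
        lam = lambda_schedule(n, N)
        total = blended_loss(loss_cm=2.0, loss_gan_g=1.0, loss_d=0.5, lam=lam)
        print(f"iter {n:>6}: lambda_gan = {lam:.3f}, blended loss = {total:.3f}")
```

Early in training the consistency objective dominates; the adversarial generator and discriminator terms are phased in as lambda_gan ramps toward 0.6, which is presumably where the faster convergence claimed in the docs change comes from.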