From 820d6e54a7fda74d8ac56fdc4e1b651815f8555e Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Fri, 10 Jan 2025 05:04:52 +0000 Subject: [PATCH] feat: ability to use no discriminator in supervised settings --- models/base_gan_model.py | 5 +- models/cut_model.py | 222 +++++++++++++++++++++----------------- options/common_options.py | 1 + tests/test_run_sr_gan.py | 2 +- 4 files changed, 130 insertions(+), 100 deletions(-) diff --git a/models/base_gan_model.py b/models/base_gan_model.py index 899328059..296085d61 100644 --- a/models/base_gan_model.py +++ b/models/base_gan_model.py @@ -138,7 +138,10 @@ def __init__(self, opt, rank): self.loss_names_G = losses_G self.loss_names_D = losses_D - self.loss_functions_G = ["compute_G_loss_GAN"] + if self.opt.D_netDs != ["none"]: + self.loss_functions_G = ["compute_G_loss_GAN"] + else: + self.loss_functions_G = [] self.forward_functions = ["forward_GAN"] if self.opt.train_semantic_mask: diff --git a/models/cut_model.py b/models/cut_model.py index d580bcd14..7f15affff 100644 --- a/models/cut_model.py +++ b/models/cut_model.py @@ -276,7 +276,10 @@ def __init__(self, opt, rank): if self.isTrain: # Discriminator(s) - self.netDs = gan_networks.define_D(**vars(opt)) + if self.opt.D_netDs != ["none"]: + self.netDs = gan_networks.define_D(**vars(opt)) + else: + self.netDs = {} self.discriminators_names = [ "D_B_" + D_name for D_name in self.netDs.keys() @@ -289,21 +292,22 @@ def __init__(self, opt, rank): # define loss functions self.criterionNCE = [] - for nce_layer in self.nce_layers: - if opt.alg_cut_nce_loss == "patchnce": - self.criterionNCE.append(PatchNCELoss(opt).to(self.device)) - elif opt.alg_cut_nce_loss == "monce": - self.criterionNCE.append(MoNCELoss(opt).to(self.device)) - elif opt.alg_cut_nce_loss == "SRC_hDCE": - self.criterionNCE.append(PatchHDCELoss(opt).to(self.device)) - - if opt.alg_cut_nce_loss == "SRC_hDCE": - self.criterionR = [] + if len(self.netDs): for nce_layer in self.nce_layers: - self.criterionR.append(SRC_Loss(opt).to(self.device)) + if opt.alg_cut_nce_loss == "patchnce": + self.criterionNCE.append(PatchNCELoss(opt).to(self.device)) + elif opt.alg_cut_nce_loss == "monce": + self.criterionNCE.append(MoNCELoss(opt).to(self.device)) + elif opt.alg_cut_nce_loss == "SRC_hDCE": + self.criterionNCE.append(PatchHDCELoss(opt).to(self.device)) - if self.opt.alg_cut_MSE_idt: - self.criterionIdt = torch.nn.L1Loss() + if opt.alg_cut_nce_loss == "SRC_hDCE": + self.criterionR = [] + for nce_layer in self.nce_layers: + self.criterionR.append(SRC_Loss(opt).to(self.device)) + + if self.opt.alg_cut_MSE_idt: + self.criterionIdt = torch.nn.L1Loss() if "MSE" in self.opt.alg_cut_supervised_loss: self.criterionSupervised = torch.nn.MSELoss() @@ -358,29 +362,30 @@ def __init__(self, opt, rank): eps=opt.train_optim_eps, ) - if len(self.discriminators_names) > 0: - D_parameters = itertools.chain( - *[ - getattr(self, "net" + D_name).parameters() - for D_name in self.discriminators_names - ] - ) - else: - D_parameters = getattr( - self, "net" + self.discriminators_names[0] - ).parameters() + if len(self.netDs): + if len(self.discriminators_names) > 0: + D_parameters = itertools.chain( + *[ + getattr(self, "net" + D_name).parameters() + for D_name in self.discriminators_names + ] + ) + else: + D_parameters = getattr( + self, "net" + self.discriminators_names[0] + ).parameters() - self.optimizer_D = opt.optim( - opt, - D_parameters, - lr=opt.train_D_lr, - betas=(opt.train_beta1, opt.train_beta2), - weight_decay=opt.train_optim_weight_decay, - eps=opt.train_optim_eps, - ) + self.optimizer_D = opt.optim( + opt, + D_parameters, + lr=opt.train_D_lr, + betas=(opt.train_beta1, opt.train_beta2), + weight_decay=opt.train_optim_weight_decay, + eps=opt.train_optim_eps, + ) + self.optimizers.append(self.optimizer_D) self.optimizers.append(self.optimizer_G) - self.optimizers.append(self.optimizer_D) if self.opt.model_multimodal: self.optimizers.append(self.optimizer_E) @@ -389,9 +394,10 @@ def __init__(self, opt, rank): networks_to_optimize = ["G_A"] optimizers = ["optimizer_G"] - if self.opt.alg_cut_lambda_NCE > 0.0: - optimizers.append("optimizer_F") - networks_to_optimize.append("F") + if len(self.netDs): + if self.opt.alg_cut_lambda_NCE > 0.0: + optimizers.append("optimizer_F") + networks_to_optimize.append("F") losses_backward = ["loss_G_tot"] if self.opt.model_multimodal: @@ -419,47 +425,51 @@ def __init__(self, opt, rank): ) self.networks_groups.append(self.group_E) - self.group_D = NetworkGroup( - networks_to_optimize=self.discriminators_names, - forward_functions=None, - backward_functions=["compute_D_loss"], - loss_names_list=["loss_names_D"], - optimizer=["optimizer_D"], - loss_backward=["loss_D_tot"], - ) - self.networks_groups.append(self.group_D) - - # Discriminators + if len(self.netDs): + self.group_D = NetworkGroup( + networks_to_optimize=self.discriminators_names, + forward_functions=None, + backward_functions=["compute_D_loss"], + loss_names_list=["loss_names_D"], + optimizer=["optimizer_D"], + loss_backward=["loss_D_tot"], + ) + self.networks_groups.append(self.group_D) - self.set_discriminators_info() + # Discriminators + self.set_discriminators_info() + else: + self.discriminators = [] # Losses names - losses_G = [] - if opt.alg_cut_lambda_NCE > 0.0: - losses_G += ["G_NCE"] - if opt.alg_cut_supervised_loss != [""]: - losses_G += ["G_supervised"] - losses_D = [] - if opt.alg_cut_nce_idt and self.isTrain: - losses_G += ["G_NCE_Y"] - - if opt.alg_cut_MSE_idt: - losses_G += ["G_MSE_idt"] - - if opt.model_multimodal and self.isTrain: - losses_E = ["G_z"] - losses_G += ["G_z"] - - if self.isTrain: + if opt.isTrain: + losses_G = [] + if opt.alg_cut_lambda_NCE > 0.0 and len(self.netDs): + losses_G += ["G_NCE"] + if opt.alg_cut_supervised_loss != [""]: + losses_G += ["G_supervised"] + if opt.alg_cut_nce_idt and self.isTrain and len(self.netDs): + losses_G += ["G_NCE_Y"] + + if opt.alg_cut_MSE_idt and len(self.netDs): + losses_G += ["G_MSE_idt"] + + if opt.model_multimodal and self.isTrain: + losses_E = ["G_z"] + losses_G += ["G_z"] + + if self.isTrain and len(self.netDs): + losses_D = [] for discriminator in self.discriminators: losses_D.append(discriminator.loss_name_D) if "mask" in discriminator.name: continue else: losses_G.append(discriminator.loss_name_G) + self.loss_names_D += losses_D - self.loss_names_G += losses_G - self.loss_names_D += losses_D + if self.isTrain: + self.loss_names_G += losses_G if self.opt.model_multimodal: self.loss_names_E = losses_E self.loss_names_G += losses_E @@ -712,9 +722,14 @@ def compute_G_loss_cut(self): fake_B_nc = self.fake_B for c in range(diffc): fake_B_nc = torch.cat((fake_B_nc, add1), dim=1) - feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, fake_B_nc) + + if len(self.netDs): + feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, fake_B_nc) else: - feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, self.fake_B) + if len(self.netDs): + feat_q_pool, feat_k_pool = self.calculate_feats( + self.real_A, self.fake_B + ) if self.opt.alg_cut_lambda_SRC > 0.0 or self.opt.alg_cut_nce_loss == "SRC_hDCE": self.loss_G_SRC, weight = self.calculate_R_loss(feat_q_pool, feat_k_pool) @@ -722,34 +737,43 @@ def compute_G_loss_cut(self): self.loss_G_SRC = 0.0 weight = None - if self.opt.alg_cut_lambda_NCE > 0.0: - self.loss_G_NCE = self.calculate_NCE_loss(feat_q_pool, feat_k_pool, weight) - else: - self.loss_G_NCE = 0.0 + if len(self.netDs): + if self.opt.alg_cut_lambda_NCE > 0.0: + self.loss_G_NCE = self.calculate_NCE_loss( + feat_q_pool, feat_k_pool, weight + ) + else: + self.loss_G_NCE = 0.0 - # Identity losses - if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_SRC > 0.0: - feat_q_pool, feat_k_pool = self.calculate_feats(self.real_B, self.idt_B) - if self.opt.alg_cut_lambda_SRC > 0.0 or self.opt.alg_cut_nce_loss == "SRC_hDCE": - self.loss_G_SRC_Y, weight = self.calculate_R_loss(feat_q_pool, feat_k_pool) - else: - self.loss_G_SRC = 0.0 - weight = None + # Identity losses + if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_SRC > 0.0: + feat_q_pool, feat_k_pool = self.calculate_feats(self.real_B, self.idt_B) + if ( + self.opt.alg_cut_lambda_SRC > 0.0 + or self.opt.alg_cut_nce_loss == "SRC_hDCE" + ): + self.loss_G_SRC_Y, weight = self.calculate_R_loss( + feat_q_pool, feat_k_pool + ) + else: + self.loss_G_SRC = 0.0 + weight = None - if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_NCE > 0.0: - self.loss_G_NCE_Y = self.calculate_NCE_loss( - feat_q_pool, feat_k_pool, weight - ) - loss_NCE_both = (self.loss_G_NCE + self.loss_G_NCE_Y) * 0.5 - else: - loss_NCE_both = self.loss_G_NCE + if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_NCE > 0.0: + self.loss_G_NCE_Y = self.calculate_NCE_loss( + feat_q_pool, feat_k_pool, weight + ) + loss_NCE_both = (self.loss_G_NCE + self.loss_G_NCE_Y) * 0.5 + else: + loss_NCE_both = self.loss_G_NCE - if self.opt.alg_cut_MSE_idt and self.opt.alg_cut_lambda_MSE_idt > 0.0: - self.loss_G_MSE_idt = self.opt.alg_cut_lambda_MSE_idt * self.criterionIdt( - self.real_B, self.idt_B - ) - else: - self.loss_G_MSE_idt = 0 + if self.opt.alg_cut_MSE_idt and self.opt.alg_cut_lambda_MSE_idt > 0.0: + self.loss_G_MSE_idt = ( + self.opt.alg_cut_lambda_MSE_idt + * self.criterionIdt(self.real_B, self.idt_B) + ) + else: + self.loss_G_MSE_idt = 0 # supervised loss with aligned data if ( @@ -792,7 +816,8 @@ def compute_G_loss_cut(self): else: self.loss_G_supervised_dists = 0 - self.loss_G_tot += loss_NCE_both + self.loss_G_MSE_idt + if len(self.netDs): + self.loss_G_tot += loss_NCE_both + self.loss_G_MSE_idt if ( self.loss_G_supervised_norm > 0 @@ -806,8 +831,9 @@ def compute_G_loss_cut(self): ) self.loss_G_tot += self.loss_G_supervised - self.compute_E_loss() - self.loss_G_tot += self.loss_G_z + if len(self.netDs): + self.compute_E_loss() + self.loss_G_tot += self.loss_G_z def compute_E_loss(self): # multimodal loss diff --git a/options/common_options.py b/options/common_options.py index 205486055..66cbb1e9b 100644 --- a/options/common_options.py +++ b/options/common_options.py @@ -450,6 +450,7 @@ def initialize(self, parser): "depth", "mask", "sam", + "none", ] + list(TORCH_MODEL_CLASSES.keys()), help="specify discriminator architecture, another option, --D_n_layers allows you to specify the layers in the n_layers discriminator. NB: duplicated arguments are ignored. Values: basic, n_layers, pixel, projected_d, temporal, vision_aided, depth, mask, sam", diff --git a/tests/test_run_sr_gan.py b/tests/test_run_sr_gan.py index 618ae0221..ccb900751 100644 --- a/tests/test_run_sr_gan.py +++ b/tests/test_run_sr_gan.py @@ -34,7 +34,7 @@ ] G_netGs = ["unet_mha", "hdit", "hat"] -D_netDs = [["basic"]] +D_netDs = [["basic"], ["none"]] product_list = product(models, G_netGs, D_netDs)