Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ability to use no discriminator in supervised settings #739

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,10 @@ def __init__(self, opt, rank):
self.loss_names_G = losses_G
self.loss_names_D = losses_D

self.loss_functions_G = ["compute_G_loss_GAN"]
if self.opt.D_netDs != ["none"]:
self.loss_functions_G = ["compute_G_loss_GAN"]
else:
self.loss_functions_G = []
self.forward_functions = ["forward_GAN"]

if self.opt.train_semantic_mask:
Expand Down
222 changes: 124 additions & 98 deletions models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,10 @@ def __init__(self, opt, rank):

if self.isTrain:
# Discriminator(s)
self.netDs = gan_networks.define_D(**vars(opt))
if self.opt.D_netDs != ["none"]:
self.netDs = gan_networks.define_D(**vars(opt))
else:
self.netDs = {}

self.discriminators_names = [
"D_B_" + D_name for D_name in self.netDs.keys()
Expand All @@ -289,21 +292,22 @@ def __init__(self, opt, rank):
# define loss functions
self.criterionNCE = []

for nce_layer in self.nce_layers:
if opt.alg_cut_nce_loss == "patchnce":
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
elif opt.alg_cut_nce_loss == "monce":
self.criterionNCE.append(MoNCELoss(opt).to(self.device))
elif opt.alg_cut_nce_loss == "SRC_hDCE":
self.criterionNCE.append(PatchHDCELoss(opt).to(self.device))

if opt.alg_cut_nce_loss == "SRC_hDCE":
self.criterionR = []
if len(self.netDs):
for nce_layer in self.nce_layers:
self.criterionR.append(SRC_Loss(opt).to(self.device))
if opt.alg_cut_nce_loss == "patchnce":
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
elif opt.alg_cut_nce_loss == "monce":
self.criterionNCE.append(MoNCELoss(opt).to(self.device))
elif opt.alg_cut_nce_loss == "SRC_hDCE":
self.criterionNCE.append(PatchHDCELoss(opt).to(self.device))

if self.opt.alg_cut_MSE_idt:
self.criterionIdt = torch.nn.L1Loss()
if opt.alg_cut_nce_loss == "SRC_hDCE":
self.criterionR = []
for nce_layer in self.nce_layers:
self.criterionR.append(SRC_Loss(opt).to(self.device))

if self.opt.alg_cut_MSE_idt:
self.criterionIdt = torch.nn.L1Loss()

if "MSE" in self.opt.alg_cut_supervised_loss:
self.criterionSupervised = torch.nn.MSELoss()
Expand Down Expand Up @@ -358,29 +362,30 @@ def __init__(self, opt, rank):
eps=opt.train_optim_eps,
)

if len(self.discriminators_names) > 0:
D_parameters = itertools.chain(
*[
getattr(self, "net" + D_name).parameters()
for D_name in self.discriminators_names
]
)
else:
D_parameters = getattr(
self, "net" + self.discriminators_names[0]
).parameters()
if len(self.netDs):
if len(self.discriminators_names) > 0:
D_parameters = itertools.chain(
*[
getattr(self, "net" + D_name).parameters()
for D_name in self.discriminators_names
]
)
else:
D_parameters = getattr(
self, "net" + self.discriminators_names[0]
).parameters()

self.optimizer_D = opt.optim(
opt,
D_parameters,
lr=opt.train_D_lr,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
eps=opt.train_optim_eps,
)
self.optimizer_D = opt.optim(
opt,
D_parameters,
lr=opt.train_D_lr,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
eps=opt.train_optim_eps,
)
self.optimizers.append(self.optimizer_D)

self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
if self.opt.model_multimodal:
self.optimizers.append(self.optimizer_E)

Expand All @@ -389,9 +394,10 @@ def __init__(self, opt, rank):

networks_to_optimize = ["G_A"]
optimizers = ["optimizer_G"]
if self.opt.alg_cut_lambda_NCE > 0.0:
optimizers.append("optimizer_F")
networks_to_optimize.append("F")
if len(self.netDs):
if self.opt.alg_cut_lambda_NCE > 0.0:
optimizers.append("optimizer_F")
networks_to_optimize.append("F")
losses_backward = ["loss_G_tot"]

if self.opt.model_multimodal:
Expand Down Expand Up @@ -419,47 +425,51 @@ def __init__(self, opt, rank):
)
self.networks_groups.append(self.group_E)

self.group_D = NetworkGroup(
networks_to_optimize=self.discriminators_names,
forward_functions=None,
backward_functions=["compute_D_loss"],
loss_names_list=["loss_names_D"],
optimizer=["optimizer_D"],
loss_backward=["loss_D_tot"],
)
self.networks_groups.append(self.group_D)

# Discriminators
if len(self.netDs):
self.group_D = NetworkGroup(
networks_to_optimize=self.discriminators_names,
forward_functions=None,
backward_functions=["compute_D_loss"],
loss_names_list=["loss_names_D"],
optimizer=["optimizer_D"],
loss_backward=["loss_D_tot"],
)
self.networks_groups.append(self.group_D)

self.set_discriminators_info()
# Discriminators
self.set_discriminators_info()
else:
self.discriminators = []

# Losses names
losses_G = []
if opt.alg_cut_lambda_NCE > 0.0:
losses_G += ["G_NCE"]
if opt.alg_cut_supervised_loss != [""]:
losses_G += ["G_supervised"]
losses_D = []
if opt.alg_cut_nce_idt and self.isTrain:
losses_G += ["G_NCE_Y"]

if opt.alg_cut_MSE_idt:
losses_G += ["G_MSE_idt"]

if opt.model_multimodal and self.isTrain:
losses_E = ["G_z"]
losses_G += ["G_z"]

if self.isTrain:
if opt.isTrain:
losses_G = []
if opt.alg_cut_lambda_NCE > 0.0 and len(self.netDs):
losses_G += ["G_NCE"]
if opt.alg_cut_supervised_loss != [""]:
losses_G += ["G_supervised"]
if opt.alg_cut_nce_idt and self.isTrain and len(self.netDs):
losses_G += ["G_NCE_Y"]

if opt.alg_cut_MSE_idt and len(self.netDs):
losses_G += ["G_MSE_idt"]

if opt.model_multimodal and self.isTrain:
losses_E = ["G_z"]
losses_G += ["G_z"]

if self.isTrain and len(self.netDs):
losses_D = []
for discriminator in self.discriminators:
losses_D.append(discriminator.loss_name_D)
if "mask" in discriminator.name:
continue
else:
losses_G.append(discriminator.loss_name_G)
self.loss_names_D += losses_D

self.loss_names_G += losses_G
self.loss_names_D += losses_D
if self.isTrain:
self.loss_names_G += losses_G
if self.opt.model_multimodal:
self.loss_names_E = losses_E
self.loss_names_G += losses_E
Expand Down Expand Up @@ -712,44 +722,58 @@ def compute_G_loss_cut(self):
fake_B_nc = self.fake_B
for c in range(diffc):
fake_B_nc = torch.cat((fake_B_nc, add1), dim=1)
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, fake_B_nc)

if len(self.netDs):
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, fake_B_nc)
else:
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, self.fake_B)
if len(self.netDs):
feat_q_pool, feat_k_pool = self.calculate_feats(
self.real_A, self.fake_B
)

if self.opt.alg_cut_lambda_SRC > 0.0 or self.opt.alg_cut_nce_loss == "SRC_hDCE":
self.loss_G_SRC, weight = self.calculate_R_loss(feat_q_pool, feat_k_pool)
else:
self.loss_G_SRC = 0.0
weight = None

if self.opt.alg_cut_lambda_NCE > 0.0:
self.loss_G_NCE = self.calculate_NCE_loss(feat_q_pool, feat_k_pool, weight)
else:
self.loss_G_NCE = 0.0
if len(self.netDs):
if self.opt.alg_cut_lambda_NCE > 0.0:
self.loss_G_NCE = self.calculate_NCE_loss(
feat_q_pool, feat_k_pool, weight
)
else:
self.loss_G_NCE = 0.0

# Identity losses
if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_SRC > 0.0:
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_B, self.idt_B)
if self.opt.alg_cut_lambda_SRC > 0.0 or self.opt.alg_cut_nce_loss == "SRC_hDCE":
self.loss_G_SRC_Y, weight = self.calculate_R_loss(feat_q_pool, feat_k_pool)
else:
self.loss_G_SRC = 0.0
weight = None
# Identity losses
if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_SRC > 0.0:
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_B, self.idt_B)
if (
self.opt.alg_cut_lambda_SRC > 0.0
or self.opt.alg_cut_nce_loss == "SRC_hDCE"
):
self.loss_G_SRC_Y, weight = self.calculate_R_loss(
feat_q_pool, feat_k_pool
)
else:
self.loss_G_SRC = 0.0
weight = None

if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_NCE > 0.0:
self.loss_G_NCE_Y = self.calculate_NCE_loss(
feat_q_pool, feat_k_pool, weight
)
loss_NCE_both = (self.loss_G_NCE + self.loss_G_NCE_Y) * 0.5
else:
loss_NCE_both = self.loss_G_NCE
if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_NCE > 0.0:
self.loss_G_NCE_Y = self.calculate_NCE_loss(
feat_q_pool, feat_k_pool, weight
)
loss_NCE_both = (self.loss_G_NCE + self.loss_G_NCE_Y) * 0.5
else:
loss_NCE_both = self.loss_G_NCE

if self.opt.alg_cut_MSE_idt and self.opt.alg_cut_lambda_MSE_idt > 0.0:
self.loss_G_MSE_idt = self.opt.alg_cut_lambda_MSE_idt * self.criterionIdt(
self.real_B, self.idt_B
)
else:
self.loss_G_MSE_idt = 0
if self.opt.alg_cut_MSE_idt and self.opt.alg_cut_lambda_MSE_idt > 0.0:
self.loss_G_MSE_idt = (
self.opt.alg_cut_lambda_MSE_idt
* self.criterionIdt(self.real_B, self.idt_B)
)
else:
self.loss_G_MSE_idt = 0

# supervised loss with aligned data
if (
Expand Down Expand Up @@ -792,7 +816,8 @@ def compute_G_loss_cut(self):
else:
self.loss_G_supervised_dists = 0

self.loss_G_tot += loss_NCE_both + self.loss_G_MSE_idt
if len(self.netDs):
self.loss_G_tot += loss_NCE_both + self.loss_G_MSE_idt

if (
self.loss_G_supervised_norm > 0
Expand All @@ -806,8 +831,9 @@ def compute_G_loss_cut(self):
)
self.loss_G_tot += self.loss_G_supervised

self.compute_E_loss()
self.loss_G_tot += self.loss_G_z
if len(self.netDs):
self.compute_E_loss()
self.loss_G_tot += self.loss_G_z

def compute_E_loss(self):
# multimodal loss
Expand Down
1 change: 1 addition & 0 deletions options/common_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def initialize(self, parser):
"depth",
"mask",
"sam",
"none",
]
+ list(TORCH_MODEL_CLASSES.keys()),
help="specify discriminator architecture, another option, --D_n_layers allows you to specify the layers in the n_layers discriminator. NB: duplicated arguments are ignored. Values: basic, n_layers, pixel, projected_d, temporal, vision_aided, depth, mask, sam",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_run_sr_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
]

G_netGs = ["unet_mha", "hdit", "hat"]
D_netDs = [["basic"]]
D_netDs = [["basic"], ["none"]]

product_list = product(models, G_netGs, D_netDs)

Expand Down