From 706356b70cb273575eb4bdba80ffdbbe1798e5a0 Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Mon, 30 Sep 2024 14:18:47 +0000 Subject: [PATCH] fix: cm at test time --- models/cm_model.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/models/cm_model.py b/models/cm_model.py index 4f27fc938..9b4758eaf 100644 --- a/models/cm_model.py +++ b/models/cm_model.py @@ -133,7 +133,10 @@ def __init__(self, opt, rank): opt.alg_diffusion_cond_embed = opt.alg_diffusion_cond_image_creation opt.alg_diffusion_cond_embed_dim = 256 self.netG_A = diffusion_networks.define_G(**vars(opt)).to(self.device) - self.netG_A.current_t = max(self.netG_A.current_t, opt.total_iters) + if opt.isTrain: + self.netG_A.current_t = max(self.netG_A.current_t, opt.total_iters) + else: + self.netG_A.current_t = 0 # placeholder print("Setting CM current_iter to", self.netG_A.current_t) self.model_names = ["G_A"] @@ -329,14 +332,15 @@ def inference(self, nb_imgs, offset=0): self.visuals = self.output # set visual names - for name in self.gen_visual_names: - whole_tensor = getattr(self, name[:-1]) - for k in range(min(nb_imgs, self.get_current_batch_size())): - cur_name = name + str(offset + k) - cur_tensor = whole_tensor[k : k + 1] - if "mask" in name: - cur_tensor = cur_tensor.squeeze(0) - setattr(self, cur_name, cur_tensor) + if self.opt.isTrain: + for name in self.gen_visual_names: + whole_tensor = getattr(self, name[:-1]) + for k in range(min(nb_imgs, self.get_current_batch_size())): + cur_name = name + str(offset + k) + cur_tensor = whole_tensor[k : k + 1] + if "mask" in name: + cur_tensor = cur_tensor.squeeze(0) + setattr(self, cur_name, cur_tensor) def compute_visuals(self, nb_imgs): super().compute_visuals(nb_imgs)