Skip to content

Commit

Permalink
fix: cm at test time
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Sep 30, 2024
1 parent 280f85a commit 706356b
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions models/cm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,10 @@ def __init__(self, opt, rank):
opt.alg_diffusion_cond_embed = opt.alg_diffusion_cond_image_creation
opt.alg_diffusion_cond_embed_dim = 256
self.netG_A = diffusion_networks.define_G(**vars(opt)).to(self.device)
self.netG_A.current_t = max(self.netG_A.current_t, opt.total_iters)
if opt.isTrain:
self.netG_A.current_t = max(self.netG_A.current_t, opt.total_iters)
else:
self.netG_A.current_t = 0 # placeholder
print("Setting CM current_iter to", self.netG_A.current_t)

self.model_names = ["G_A"]
Expand Down Expand Up @@ -329,14 +332,15 @@ def inference(self, nb_imgs, offset=0):
self.visuals = self.output

# set visual names
for name in self.gen_visual_names:
whole_tensor = getattr(self, name[:-1])
for k in range(min(nb_imgs, self.get_current_batch_size())):
cur_name = name + str(offset + k)
cur_tensor = whole_tensor[k : k + 1]
if "mask" in name:
cur_tensor = cur_tensor.squeeze(0)
setattr(self, cur_name, cur_tensor)
if self.opt.isTrain:
for name in self.gen_visual_names:
whole_tensor = getattr(self, name[:-1])
for k in range(min(nb_imgs, self.get_current_batch_size())):
cur_name = name + str(offset + k)
cur_tensor = whole_tensor[k : k + 1]
if "mask" in name:
cur_tensor = cur_tensor.squeeze(0)
setattr(self, cur_name, cur_tensor)

def compute_visuals(self, nb_imgs):
super().compute_visuals(nb_imgs)
Expand Down

0 comments on commit 706356b

Please sign in to comment.