Skip to content

Commit

Permalink
fix: consistency models with input/output different channels
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Sep 12, 2024
1 parent d05b337 commit db61821
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions models/cm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def set_input(self, data):
self.real_B = self.gt_image

def compute_cm_loss(self):

y_0 = self.gt_image # ground truth
y_cond = self.cond_image # conditioning
mask = self.mask
Expand Down Expand Up @@ -306,7 +305,6 @@ def inference(self, nb_imgs, offset=0):
netG = revert_sync_batchnorm(netG)

# XXX: inpainting only for now

if self.mask is not None:
mask = self.mask[:nb_imgs]
else:
Expand All @@ -318,9 +316,15 @@ def inference(self, nb_imgs, offset=0):
y_cond = self.cond_image[:nb_imgs]
else:
y_cond = None
self.output = netG.restoration(
self.y_t[:nb_imgs], y_cond, sampling_sigmas, mask
)
if (
self.task == "pix2pix"
): # y_t must be of output channel size, since we do not have y_0 (gt), we get it from the model
out_shape = list(y_cond.shape)
out_shape[1] = netG.cm_model.out_channel
y_t = torch.zeros(out_shape, device=y_cond.device, dtype=y_cond.dtype)
else: # e.g. inpainting
y_t = self.y_t[:nb_imgs]
self.output = netG.restoration(y_t, y_cond, sampling_sigmas, mask)
self.fake_B = self.output
self.visuals = self.output

Expand Down

0 comments on commit db61821

Please sign in to comment.