From db618210a733e33aa0037f101019cd66850e145b Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Wed, 11 Sep 2024 15:11:20 +0000 Subject: [PATCH] fix: consistency models with input/output different channels --- models/cm_model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/models/cm_model.py b/models/cm_model.py index 85a20ed6e..4f27fc938 100644 --- a/models/cm_model.py +++ b/models/cm_model.py @@ -253,7 +253,6 @@ def set_input(self, data): self.real_B = self.gt_image def compute_cm_loss(self): - y_0 = self.gt_image # ground truth y_cond = self.cond_image # conditioning mask = self.mask @@ -306,7 +305,6 @@ def inference(self, nb_imgs, offset=0): netG = revert_sync_batchnorm(netG) # XXX: inpainting only for now - if self.mask is not None: mask = self.mask[:nb_imgs] else: @@ -318,9 +316,15 @@ def inference(self, nb_imgs, offset=0): y_cond = self.cond_image[:nb_imgs] else: y_cond = None - self.output = netG.restoration( - self.y_t[:nb_imgs], y_cond, sampling_sigmas, mask - ) + if ( + self.task == "pix2pix" + ): # y_t must be of output channel size, since we do not have y_0 (gt), we get it from the model + out_shape = list(y_cond.shape) + out_shape[1] = netG.cm_model.out_channel + y_t = torch.zeros(out_shape, device=y_cond.device, dtype=y_cond.dtype) + else: # e.g. inpainting + y_t = self.y_t[:nb_imgs] + self.output = netG.restoration(y_t, y_cond, sampling_sigmas, mask) self.fake_B = self.output self.visuals = self.output