Skip to content

Commit cd264de

Browse files
committed
fix: diffusion with input and output of different channel size
1 parent 5046c83 commit cd264de

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

models/diffusion_networks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def define_G(
9090
norm_layer = get_norm_layer(norm_type=G_norm)
9191

9292
if model_type == "palette":
93-
in_channel = model_input_nc * 2
93+
in_channel = model_input_nc + model_output_nc
9494
else: # CM
9595
in_channel = model_input_nc
9696
if (

models/modules/diffusion_generator.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,15 @@ def restoration_ddpm(
138138
), "num_timesteps must greater than sample_num"
139139
sample_inter = self.denoise_fn.model.num_timesteps_test // sample_num
140140

141-
y_t = self.default(y_t, lambda: torch.randn_like(y_cond))
141+
# y_t must be of output channel size, since we do not have y_0 (gt), we get it from the model
142+
y_t_shape = list(y_cond.shape)
143+
y_t_shape[1] = (
144+
self.denoise_fn.model.out_channel
145+
) # set to number of model output channels
146+
y_t = self.default(
147+
y_t,
148+
lambda: torch.randn(y_t_shape, device=y_cond.device, dtype=y_cond.dtype),
149+
)
142150
ret_arr = y_t
143151

144152
for i in tqdm(
@@ -439,6 +447,8 @@ def ddim_p_mean_variance(
439447

440448
def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
441449
sequence_length = 0
450+
451+
# vid only
442452
if len(y_0.shape) == 5:
443453
sequence_length = y_0.shape[1]
444454
y_0, y_cond, mask = rearrange_5dto4d(y_0, y_cond, mask)

0 commit comments

Comments
 (0)