Skip to content

Commit 401ea4a

Browse files
IamTingTingKumoLiuericspod
authored
Fix hardcoded input dim in DiffusionModelEncoder (#8514)
Fixes #8496 ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: IamTingTing <[email protected]> Co-authored-by: YunLiu <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
1 parent fd13c1b commit 401ea4a

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

monai/networks/nets/diffusion_model_unet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import math
3535
from collections.abc import Sequence
36+
from typing import Optional
3637

3738
import torch
3839
from torch import nn
@@ -2006,7 +2007,7 @@ def __init__(
20062007

20072008
self.down_blocks.append(down_block)
20082009

2009-
self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))
2010+
self.out: Optional[nn.Module] = None
20102011

20112012
def forward(
20122013
self,
@@ -2049,6 +2050,12 @@ def forward(
20492050
h, _ = downsample_block(hidden_states=h, temb=emb, context=context)
20502051

20512052
h = h.reshape(h.shape[0], -1)
2053+
2054+
# 5. out
2055+
if self.out is None:
2056+
self.out = nn.Sequential(
2057+
nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
2058+
)
20522059
output: torch.Tensor = self.out(h)
20532060

20542061
return output

0 commit comments

Comments
 (0)