Skip to content

Commit 95e70fa

Browse files
authored
🐛 Fix MapDe dist_filter Shape (#914)
- Fix `dist_filter` in `MapDe` model for multi-class output. Explanation: Previously, if we set `num_class` to more than 1, the model would still output 1 channel. This was because the `dist_filter` always had size of 1 in its first dimension, however the first dimension determines the number of output channels in the tensor produced by `torch.functional.F.conv2d`. This PR changes this by repeating the filters the match the number of output classes.
1 parent 9021b57 commit 95e70fa

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

tests/models/test_arch_mapde.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,12 @@ def test_functionality(remote_sample: Callable) -> None:
4848
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
4949
output = model.postproc(output[0])
5050
assert np.all(output[0:2] == [[19, 171], [53, 89]])
51+
52+
53+
def test_multiclass_output() -> None:
54+
"""Test the architecture for multi-class output."""
55+
multiclass_model = MapDe(num_input_channels=3, num_classes=3)
56+
test_input = torch.rand((1, 3, 252, 252))
57+
58+
output = multiclass_model(test_input)
59+
assert output.shape == (1, 3, 252, 252)

tiatoolbox/models/architecture/mapde.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,11 @@ def __init__(
199199
dtype=np.float32,
200200
)
201201

202-
dist_filter = np.expand_dims(dist_filter, axis=(0, 1)) # NCHW
202+
# For conv2d, filter shape = (out_channels, in_channels//groups, H, W)
203+
dist_filter = np.expand_dims(dist_filter, axis=(0, 1))
203204
dist_filter = np.repeat(dist_filter, repeats=num_classes * 2, axis=1)
205+
# Need to repeat for out_channels
206+
dist_filter = np.repeat(dist_filter, repeats=num_classes, axis=0)
204207

205208
self.min_distance = min_distance
206209
self.threshold_abs = threshold_abs

0 commit comments

Comments
 (0)