diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index 2a65583c6..4ec404826 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -48,3 +48,12 @@ def test_functionality(remote_sample: Callable) -> None: output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) output = model.postproc(output[0]) assert np.all(output[0:2] == [[19, 171], [53, 89]]) + + +def test_multiclass_output() -> None: + """Test the architecture for multi-class output.""" + multiclass_model = MapDe(num_input_channels=3, num_classes=3) + test_input = torch.rand((1, 3, 252, 252)) + + output = multiclass_model(test_input) + assert output.shape == (1, 3, 252, 252) diff --git a/tiatoolbox/models/architecture/mapde.py b/tiatoolbox/models/architecture/mapde.py index bbb468bb8..0900aa6fd 100644 --- a/tiatoolbox/models/architecture/mapde.py +++ b/tiatoolbox/models/architecture/mapde.py @@ -199,8 +199,11 @@ def __init__( dtype=np.float32, ) - dist_filter = np.expand_dims(dist_filter, axis=(0, 1)) # NCHW + # For conv2d, filter shape = (out_channels, in_channels//groups, H, W) + dist_filter = np.expand_dims(dist_filter, axis=(0, 1)) dist_filter = np.repeat(dist_filter, repeats=num_classes * 2, axis=1) + # Need to repeat for out_channels + dist_filter = np.repeat(dist_filter, repeats=num_classes, axis=0) self.min_distance = min_distance self.threshold_abs = threshold_abs