diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index a747bfbe1b..6d62878cb5 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -194,7 +194,7 @@ class SemanticSegmentationDeserializer(ImageDeserializer): def deserialize(self, data: str) -> Dict[str, Any]: result = super().deserialize(data) result[DataKeys.INPUT] = FT.to_tensor(result[DataKeys.INPUT]) - result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape} + result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape[-2:]} return result