diff --git a/batchgeneratorsv2/transforms/utils/seg_to_regions.py b/batchgeneratorsv2/transforms/utils/seg_to_regions.py index efde449..1fa1a13 100644 --- a/batchgeneratorsv2/transforms/utils/seg_to_regions.py +++ b/batchgeneratorsv2/transforms/utils/seg_to_regions.py @@ -17,7 +17,7 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch. if len(region_labels) == 1: region_output[region_id] = segmentation[self.channel_in_seg] == region_labels else: - region_output[region_id] = torch.isin(segmentation[self.channel_in_seg], region_labels) + region_output[region_id] = torch.isin(segmentation[self.channel_in_seg], torch.tensor(region_labels).to(segmentation.dtype)) # we return bool here and leave it to the loss function to cast it to whatever it needs. Transferring bool to # device followed by cast on device should be faster than having fp32 here and transferring that return region_output