From 16243cc24c4afae196d534bce26a0709cbc91c92 Mon Sep 17 00:00:00 2001 From: mansourhas <37121036+mansourhas@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:56:29 +0300 Subject: [PATCH] Update seg_to_regions.py fix error message ` isin() received an invalid combination of arguments - got (Tensor, tuple), ` --- batchgeneratorsv2/transforms/utils/seg_to_regions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/batchgeneratorsv2/transforms/utils/seg_to_regions.py b/batchgeneratorsv2/transforms/utils/seg_to_regions.py index efde449..1fa1a13 100644 --- a/batchgeneratorsv2/transforms/utils/seg_to_regions.py +++ b/batchgeneratorsv2/transforms/utils/seg_to_regions.py @@ -17,7 +17,7 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch. if len(region_labels) == 1: region_output[region_id] = segmentation[self.channel_in_seg] == region_labels else: - region_output[region_id] = torch.isin(segmentation[self.channel_in_seg], region_labels) + region_output[region_id] = torch.isin(segmentation[self.channel_in_seg], torch.tensor(region_labels).to(segmentation.dtype)) # we return bool here and leave it to the loss function to cast it to whatever it needs. Transferring bool to # device followed by cast on device should be faster than having fp32 here and transferring that return region_output