diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py
index f96573e262..8ee8382002 100644
--- a/flash/image/segmentation/data.py
+++ b/flash/image/segmentation/data.py
@@ -63,7 +63,8 @@ if _TORCHVISION_AVAILABLE:
     import torchvision
-    from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
+    import torchvision.transforms.functional as FT
+    from torchvision.datasets.folder import default_loader, has_file_allowed_extension, IMG_EXTENSIONS
 else:
     IMG_EXTENSIONS = None
@@ -148,7 +149,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten
         img_labels_path = sample[DefaultDataKeys.TARGET]
 
         # load images directly to torch tensors
-        img: torch.Tensor = torchvision.io.read_image(img_path)  # CxHxW
+        img: torch.Tensor = FT.to_tensor(default_loader(img_path))  # CxHxW
         img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path)  # CxHxW
         img_labels = img_labels[0]  # HxW
 
@@ -163,7 +164,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten
     @staticmethod
     def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]:
         img_path = sample[DefaultDataKeys.INPUT]
-        img = torchvision.io.read_image(img_path).float()
+        img = FT.to_tensor(default_loader(img_path)).float()
 
         sample[DefaultDataKeys.INPUT] = img
         sample[DefaultDataKeys.METADATA] = {
@@ -195,7 +196,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten
         img_path = sample[DefaultDataKeys.INPUT]
         fo_sample = _fo_dataset[img_path]
 
-        img: torch.Tensor = torchvision.io.read_image(img_path)  # CxHxW
+        img: torch.Tensor = FT.to_tensor(default_loader(img_path))  # CxHxW
         img_labels: torch.Tensor = torch.from_numpy(fo_sample[self.label_field].mask)  # HxW
 
         sample[DefaultDataKeys.INPUT] = img.float()
@@ -209,7 +210,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten
     @staticmethod
     def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]:
         img_path = sample[DefaultDataKeys.INPUT]
-        img = torchvision.io.read_image(img_path).float()
+        img = FT.to_tensor(default_loader(img_path)).float()
 
         sample[DefaultDataKeys.INPUT] = img
         sample[DefaultDataKeys.METADATA] = {
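
For reviewers, a minimal sketch (not part of the patch) of how the two loading paths differ; `example.png` is a placeholder path. `torchvision.io.read_image` decodes straight to a uint8 `CxHxW` tensor in `[0, 255]`, while `default_loader` goes through PIL (converting to RGB, which typically handles more file formats) and `FT.to_tensor` rescales to float32 in `[0, 1]`:

```python
# Illustration only -- not part of the patch. "example.png" is a placeholder.
import torchvision
import torchvision.transforms.functional as FT
from torchvision.datasets.folder import default_loader

img_path = "example.png"  # any RGB image on disk

# Old loading path: decodes directly to a uint8 tensor, values in [0, 255].
old_img = torchvision.io.read_image(img_path)     # CxHxW, torch.uint8

# New loading path: PIL-based loader (converted to RGB), then
# to_tensor() yields float32 values scaled to [0, 1].
new_img = FT.to_tensor(default_loader(img_path))  # CxHxW, torch.float32

print(old_img.dtype, old_img.max().item())  # torch.uint8, up to 255
print(new_img.dtype, new_img.max().item())  # torch.float32, up to 1.0
```

Note that the segmentation masks are still read with `torchvision.io.read_image`, so only the input image tensors change from a uint8 `[0, 255]` range to a float `[0, 1]` range.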