From 9b86a0ffbf849a33fd4851b118d128702c0e6ab7 Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Tue, 17 Aug 2021 06:04:46 -0400 Subject: [PATCH] read_image to default_loader (#669) * read_image to default_loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * imports Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- flash/image/segmentation/data.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index f96573e262..8ee8382002 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -63,7 +63,8 @@ if _TORCHVISION_AVAILABLE: import torchvision - from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS + import torchvision.transforms.functional as FT + from torchvision.datasets.folder import default_loader, has_file_allowed_extension, IMG_EXTENSIONS else: IMG_EXTENSIONS = None @@ -148,7 +149,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten img_labels_path = sample[DefaultDataKeys.TARGET] # load images directly to torch tensors - img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW + img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW img_labels = img_labels[0] # HxW @@ -163,7 +164,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten @staticmethod def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: img_path = sample[DefaultDataKeys.INPUT] - img = torchvision.io.read_image(img_path).float() + img = FT.to_tensor(default_loader(img_path)).float() sample[DefaultDataKeys.INPUT] = img sample[DefaultDataKeys.METADATA] = { @@ -195,7 +196,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten img_path = sample[DefaultDataKeys.INPUT] fo_sample = _fo_dataset[img_path] - img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW + img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW img_labels: torch.Tensor = torch.from_numpy(fo_sample[self.label_field].mask) # HxW sample[DefaultDataKeys.INPUT] = img.float() @@ -209,7 +210,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten @staticmethod def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: img_path = sample[DefaultDataKeys.INPUT] - img = torchvision.io.read_image(img_path).float() + img = FT.to_tensor(default_loader(img_path)).float() sample[DefaultDataKeys.INPUT] = img sample[DefaultDataKeys.METADATA] = {