Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

read_image to default_loader #669

Merged
merged 5 commits into from
Aug 17, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@

if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
import torchvision.transforms.functional as FT
from torchvision.datasets.folder import default_loader, has_file_allowed_extension, IMG_EXTENSIONS
else:
IMG_EXTENSIONS = None

Expand Down Expand Up @@ -148,7 +149,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten
img_labels_path = sample[DefaultDataKeys.TARGET]

# load images directly to torch tensors
img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW
img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW
Copy link
Collaborator

@ethanwharris ethanwharris Aug 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would maybe be better to just use default_loader here and then convert to tensor in the preprocess to_tensor_transform. Could make a good follow up PR 😃

img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW
img_labels = img_labels[0] # HxW

Expand All @@ -163,7 +164,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten
@staticmethod
def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]:
img_path = sample[DefaultDataKeys.INPUT]
img = torchvision.io.read_image(img_path).float()
img = FT.to_tensor(default_loader(img_path)).float()

sample[DefaultDataKeys.INPUT] = img
sample[DefaultDataKeys.METADATA] = {
Expand Down Expand Up @@ -195,7 +196,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten
img_path = sample[DefaultDataKeys.INPUT]
fo_sample = _fo_dataset[img_path]

img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW
img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW
img_labels: torch.Tensor = torch.from_numpy(fo_sample[self.label_field].mask) # HxW

sample[DefaultDataKeys.INPUT] = img.float()
Expand All @@ -209,7 +210,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten
@staticmethod
def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]:
img_path = sample[DefaultDataKeys.INPUT]
img = torchvision.io.read_image(img_path).float()
img = FT.to_tensor(default_loader(img_path)).float()

sample[DefaultDataKeys.INPUT] = img
sample[DefaultDataKeys.METADATA] = {
Expand Down