Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ananyahjha93 committed Aug 17, 2021
1 parent 3c7e2db commit 4931cd7
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
plt = None

if _TORCHVISION_AVAILABLE:
import torchvision
import torchvision.transforms.functional as FT
from torchvision.datasets.folder import default_loader, has_file_allowed_extension, IMG_EXTENSIONS
else:
Expand Down Expand Up @@ -149,7 +150,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten

# load images directly to torch tensors
img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW
img_labels: torch.Tensor = FT.to_tensor(default_loader(img_labels_path)) # CxHxW
img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW
img_labels = img_labels[0] # HxW

sample[DefaultDataKeys.INPUT] = img.float()
Expand Down

0 comments on commit 4931cd7

Please sign in to comment.