Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

added .csv image loading utils #118

Merged
merged 6 commits into from
Feb 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union

import pytorch_lightning as pl
Expand Down Expand Up @@ -126,7 +127,7 @@ def predict(

Args:

x: Input to predict. Can be raw data or processed data.
x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data.

batch_idx: Batch index

Expand All @@ -142,6 +143,12 @@ def predict(
The post-processed model predictions

"""
# enable x to be a path to a folder
if isinstance(x, str):
files = os.listdir(x)
files = [os.path.join(x, y) for y in files]
x = files

data_pipeline = data_pipeline or self.data_pipeline
batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
batch_x, batch_y = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None)
Expand Down
12 changes: 8 additions & 4 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pandas as pd
import torch
from PIL import Image
from PIL import Image, UnidentifiedImageError
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torchvision import transforms as T
from torchvision.datasets import VisionDataset
Expand Down Expand Up @@ -241,9 +241,13 @@ def before_collate(self, samples: Any) -> Any:
if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
outputs = []
for sample in samples:
output = self._loader(sample)
transform = self._valid_transform if self._use_valid_transform else self._train_transform
outputs.append(transform(output))
try:
output = self._loader(sample)
transform = self._valid_transform if self._use_valid_transform else self._train_transform
outputs.append(transform(output))
except UnidentifiedImageError:
print(f'Skipping: could not read file {sample}')

return outputs
raise MisconfigurationException("The samples should either be a tensor or a list of paths.")

Expand Down
19 changes: 19 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Any

import numpy as np
import pytest
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import nn
from torch.nn import functional as F

Expand Down Expand Up @@ -68,6 +71,18 @@ def test_classificationtask_task_predict():
assert pred0[0] == pred1[0]


def test_classification_task_predict_folder_path(tmpdir):
train_dir = Path(tmpdir / "train")
train_dir.mkdir()

_rand_image().save(train_dir / "1.png")
_rand_image().save(train_dir / "2.png")

task = ImageClassifier(num_classes=10)
predictions = task.predict(str(train_dir))
assert len(predictions) == 2


def test_classificationtask_trainer_predict(tmpdir):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
task = ClassificationTask(model)
Expand Down Expand Up @@ -127,3 +142,7 @@ def test_model_download(tmpdir, cls, filename):
with tmpdir.as_cwd():
task = cls.load_from_checkpoint(url + filename)
assert isinstance(task, cls)


def _rand_image():
return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))