Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
added .csv image loading utils (#118)
Browse files Browse the repository at this point in the history
* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils
  • Loading branch information
williamFalcon authored Feb 14, 2021
1 parent 0c160a2 commit b41e33d
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
9 changes: 8 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union

import pytorch_lightning as pl
Expand Down Expand Up @@ -126,7 +127,7 @@ def predict(
Args:
x: Input to predict. Can be raw data or processed data.
x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data.
batch_idx: Batch index
Expand All @@ -142,6 +143,12 @@ def predict(
The post-processed model predictions
"""
# enable x to be a path to a folder
if isinstance(x, str):
files = os.listdir(x)
files = [os.path.join(x, y) for y in files]
x = files

data_pipeline = data_pipeline or self.data_pipeline
batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
batch_x, batch_y = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None)
Expand Down
12 changes: 8 additions & 4 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pandas as pd
import torch
from PIL import Image
from PIL import Image, UnidentifiedImageError
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torchvision import transforms as T
from torchvision.datasets import VisionDataset
Expand Down Expand Up @@ -241,9 +241,13 @@ def before_collate(self, samples: Any) -> Any:
if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
outputs = []
for sample in samples:
output = self._loader(sample)
transform = self._valid_transform if self._use_valid_transform else self._train_transform
outputs.append(transform(output))
try:
output = self._loader(sample)
transform = self._valid_transform if self._use_valid_transform else self._train_transform
outputs.append(transform(output))
except UnidentifiedImageError:
print(f'Skipping: could not read file {sample}')

return outputs
raise MisconfigurationException("The samples should either be a tensor or a list of paths.")

Expand Down
19 changes: 19 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Any

import numpy as np
import pytest
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import nn
from torch.nn import functional as F

Expand Down Expand Up @@ -68,6 +71,18 @@ def test_classificationtask_task_predict():
assert pred0[0] == pred1[0]


def test_classification_task_predict_folder_path(tmpdir):
train_dir = Path(tmpdir / "train")
train_dir.mkdir()

_rand_image().save(train_dir / "1.png")
_rand_image().save(train_dir / "2.png")

task = ImageClassifier(num_classes=10)
predictions = task.predict(str(train_dir))
assert len(predictions) == 2


def test_classificationtask_trainer_predict(tmpdir):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
task = ClassificationTask(model)
Expand Down Expand Up @@ -127,3 +142,7 @@ def test_model_download(tmpdir, cls, filename):
with tmpdir.as_cwd():
task = cls.load_from_checkpoint(url + filename)
assert isinstance(task, cls)


def _rand_image():
return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))

0 comments on commit b41e33d

Please sign in to comment.