Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Image classification csv data source #556

Merged
merged 8 commits into from
Jul 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method ([#552](https://github.com/PyTorchLightning/lightning-flash/pull/552))
- Added support for `from_csv` and `from_data_frame` to `ImageClassificationData` ([#556](https://github.com/PyTorchLightning/lightning-flash/pull/556))

### Changed

Expand Down
296 changes: 293 additions & 3 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import glob
import os
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import torch
from pytorch_lightning.trainer.states import RunningStage
from torch.utils.data.sampler import Sampler

from flash.core.data.base_viz import BaseVisualization # for viz
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, LabelsState
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _requires_extras
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _requires_extras, _TORCHVISION_AVAILABLE
from flash.image.classification.transforms import default_transforms, train_default_transforms
from flash.image.data import (
ImageDeserializer,
Expand All @@ -37,6 +42,9 @@
else:
plt = None

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader

if _PIL_AVAILABLE:
from PIL import Image
else:
Expand All @@ -45,6 +53,96 @@ class Image:
Image = None


class ImageClassificationDataFrameDataSource(
DataSource[Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str]]]
):

@staticmethod
def _resolve_file(root: str, file_id: str) -> str:
if os.path.isabs(file_id):
pattern = f"{file_id}*"
else:
pattern = os.path.join(root, f"*{file_id}*")
files = glob.glob(pattern)
if len(files) > 1:
raise ValueError(
f"Found multiple matches for pattern: {pattern}. File IDs should uniquely identify the file to load."
)
elif len(files) == 0:
raise ValueError(
f"Found no matches for pattern: {pattern}. File IDs should uniquely identify the file to load."
)
return files[0]

@staticmethod
def _resolve_target(label_to_class: Dict[str, int], target_key: str, row: pd.Series) -> pd.Series:
row[target_key] = label_to_class[row[target_key]]
return row

@staticmethod
def _resolve_multi_target(target_keys: List[str], row: pd.Series) -> pd.Series:
row[target_keys[0]] = [row[target_key] for target_key in target_keys]
return row

def load_data(
self,
data: Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str]],
dataset: Optional[Any] = None,
) -> Sequence[Mapping[str, Any]]:
data_frame, input_key, target_keys, root = data
if root is None:
root = ""

if not self.predicting:
if isinstance(target_keys, List):
dataset.num_classes = len(target_keys)
self.set_state(LabelsState(target_keys))
data_frame = data_frame.apply(partial(self._resolve_multi_target, target_keys), axis=1)
target_keys = target_keys[0]
else:
if self.training:
labels = list(sorted(data_frame[target_keys].unique()))
dataset.num_classes = len(labels)
self.set_state(LabelsState(labels))

labels = self.get_state(LabelsState)

if labels is not None:
labels = labels.labels
label_to_class = {v: k for k, v in enumerate(labels)}
data_frame = data_frame.apply(partial(self._resolve_target, label_to_class, target_keys), axis=1)

return [{
DefaultDataKeys.INPUT: row[input_key],
DefaultDataKeys.TARGET: row[target_keys],
DefaultDataKeys.METADATA: dict(root=root),
} for _, row in data_frame.iterrows()]
else:
return [{
DefaultDataKeys.INPUT: row[input_key],
DefaultDataKeys.METADATA: dict(root=root),
} for _, row in data_frame.iterrows()]

def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
file = self._resolve_file(sample[DefaultDataKeys.METADATA]['root'], sample[DefaultDataKeys.INPUT])
sample[DefaultDataKeys.INPUT] = default_loader(file)
return sample


class ImageClassificationCSVDataSource(ImageClassificationDataFrameDataSource):

def load_data(
self,
data: Tuple[str, str, Union[str, List[str]], Optional[str]],
dataset: Optional[Any] = None,
) -> Sequence[Mapping[str, Any]]:
csv_file, input_key, target_keys, root = data
data_frame = pd.read_csv(csv_file)
if root is None:
root = os.path.dirname(csv_file)
return super().load_data((data_frame, input_key, target_keys, root), dataset)


class ImageClassificationPreprocess(Preprocess):

def __init__(
Expand All @@ -70,6 +168,8 @@ def __init__(
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
DefaultDataSources.NUMPY: ImageNumpyDataSource(),
DefaultDataSources.TENSORS: ImageTensorDataSource(),
"data_frame": ImageClassificationDataFrameDataSource(),
DefaultDataSources.CSV: ImageClassificationCSVDataSource(),
},
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
Expand All @@ -94,6 +194,196 @@ class ImageClassificationData(DataModule):

preprocess_cls = ImageClassificationPreprocess

@classmethod
def from_data_frame(
cls,
input_field: str,
target_fields: Optional[Union[str, Sequence[str]]] = None,
train_data_frame: Optional[pd.DataFrame] = None,
train_images_root: Optional[str] = None,
val_data_frame: Optional[pd.DataFrame] = None,
val_images_root: Optional[str] = None,
test_data_frame: Optional[pd.DataFrame] = None,
test_images_root: Optional[str] = None,
predict_data_frame: Optional[pd.DataFrame] = None,
predict_images_root: Optional[str] = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given pandas
``DataFrame`` objects.
Args:
input_field: The field (column) in the pandas ``DataFrame`` to use for the input.
target_fields: The field or fields (columns) in the pandas ``DataFrame`` to use for the target.
train_data_frame: The pandas ``DataFrame`` containing the training data.
train_images_root: The directory containing the train images. If ``None``, values in the ``input_field``
will be assumed to be the full file paths.
val_data_frame: The pandas ``DataFrame`` containing the validation data.
val_images_root: The directory containing the validation images. If ``None``, the directory containing the
``val_file`` will be used.
test_data_frame: The pandas ``DataFrame`` containing the testing data.
test_images_root: The directory containing the test images. If ``None``, the directory containing the
``test_file`` will be used.
predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting.
predict_images_root: The directory containing the predict images. If ``None``, the directory containing the
``predict_file`` will be used.
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
test_transform: The dictionary of transforms to use during testing which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
predict_transform: The dictionary of transforms to use during predicting which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
:class:`~flash.core.data.data_module.DataModule`.
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
:class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
will be constructed and used.
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Returns:
The constructed data module.
Examples::
data_module = ImageClassificationData.from_data_frame(
"image_id",
"target",
train_data_frame=train_data,
train_images_root="data/train_images",
)
"""
return cls.from_data_source(
"data_frame",
(train_data_frame, input_field, target_fields, train_images_root),
(val_data_frame, input_field, target_fields, val_images_root),
(test_data_frame, input_field, target_fields, test_images_root),
(predict_data_frame, input_field, target_fields, predict_images_root),
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_fetcher=data_fetcher,
preprocess=preprocess,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

@classmethod
def from_csv(
cls,
input_field: str,
target_fields: Optional[Union[str, Sequence[str]]] = None,
train_file: Optional[str] = None,
train_images_root: Optional[str] = None,
val_file: Optional[str] = None,
val_images_root: Optional[str] = None,
test_file: Optional[str] = None,
test_images_root: Optional[str] = None,
predict_file: Optional[str] = None,
predict_images_root: Optional[str] = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given CSV files
using the :class:`~flash.core.data.data_source.DataSource`
of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV`
from the passed or constructed :class:`~flash.core.data.process.Preprocess`.
Args:
input_field: The field (column) in the CSV file to use for the input.
target_fields: The field or fields (columns) in the CSV file to use for the target.
train_file: The CSV file containing the training data.
train_images_root: The directory containing the train images. If ``None``, the directory containing the
``train_file`` will be used.
val_file: The CSV file containing the validation data.
val_images_root: The directory containing the validation images. If ``None``, the directory containing the
``val_file`` will be used.
test_file: The CSV file containing the testing data.
test_images_root: The directory containing the test images. If ``None``, the directory containing the
``test_file`` will be used.
predict_file: The CSV file containing the data to use when predicting.
predict_images_root: The directory containing the predict images. If ``None``, the directory containing the
``predict_file`` will be used.
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
test_transform: The dictionary of transforms to use during testing which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
predict_transform: The dictionary of transforms to use during predicting which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
:class:`~flash.core.data.data_module.DataModule`.
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
:class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
will be constructed and used.
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Returns:
The constructed data module.
Examples::
data_module = ImageClassificationData.from_csv(
"image_id",
"target",
train_file="train_data.csv",
train_images_root="data/train_images",
)
"""
return cls.from_data_source(
DefaultDataSources.CSV,
(train_file, input_field, target_fields, train_images_root),
(val_file, input_field, target_fields, val_images_root),
(test_file, input_field, target_fields, test_images_root),
(predict_file, input_field, target_fields, predict_images_root),
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_fetcher=data_fetcher,
preprocess=preprocess,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

def set_block_viz_window(self, value: bool) -> None:
"""Setter method to switch on/off matplotlib to pop up windows."""
self.data_fetcher.block_viz_window = value
Expand Down
2 changes: 1 addition & 1 deletion flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
metrics=metrics or F1(num_classes) if multi_label else Accuracy(),
learning_rate=learning_rate,
multi_label=multi_label,
serializer=serializer or Labels(),
serializer=serializer or Labels(multi_label=multi_label),
)

self.save_hyperparameters()
Expand Down
Loading