Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Initial commit
Browse files: browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jul 16, 2021
1 parent 4cf522e commit 5802dcf
Show file tree
Hide file tree
Showing 11 changed files with 467 additions and 197 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ jigsaw_toxic_comments
flash_examples/serve/tabular_classification/data
logs/cache/*
flash_examples/data
flash_examples/checkpoints
5 changes: 3 additions & 2 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,13 @@ def _predict_dataloader(self) -> DataLoader:
pin_memory = True

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_test_dataset(
return self.trainer.lightning_module.process_predict_dataset(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn
collate_fn=collate_fn,
convert_to_dataloader=True,
)

return DataLoader(
Expand Down
5 changes: 5 additions & 0 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def _create_collate_preprocessors(
prefix: str = _STAGES_PREFIX[stage]

if collate_fn is not None:
preprocess._original_default_collate = preprocess._default_collate
preprocess._default_collate = collate_fn

func_names: Dict[str, str] = {
Expand Down Expand Up @@ -486,6 +487,10 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin
elif isinstance(stage, RunningStage):
stages = [stage]

self._preprocess_pipeline._default_collate = getattr(
self._preprocess_pipeline, "_original_default_collate", self._preprocess_pipeline._default_collate
)

for stage in stages:

device_collate = None
Expand Down
4 changes: 1 addition & 3 deletions flash/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Union

from pytorch_lightning.utilities import rank_zero_info
Expand Down Expand Up @@ -76,7 +74,7 @@ def _register_function(
override: bool = False,
metadata: Optional[Dict[str, Any]] = None
):
if not isinstance(fn, FunctionType) and not isinstance(fn, partial):
if not callable(fn):
raise MisconfigurationException(f"You can only register a function, found: {fn}")

name = name or fn.__name__
Expand Down
2 changes: 2 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
_ICEVISION_AVAILABLE = _module_available("icevision")

if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
Expand All @@ -103,6 +104,7 @@ def _compare_version(package: str, op, version) -> bool:
_KORNIA_AVAILABLE,
_PYSTICHE_AVAILABLE,
_SEGMENTATION_MODELS_AVAILABLE,
_ICEVISION_AVAILABLE,
])
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE
Expand Down
186 changes: 112 additions & 74 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TYPE_CHECKING
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, TYPE_CHECKING

import numpy as np

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, FiftyOneDataSource
from flash.core.data.data_source import DefaultDataKeys, FiftyOneDataSource
from flash.core.data.process import Preprocess
from flash.core.utilities.imports import (
_COCO_AVAILABLE,
_FIFTYONE_AVAILABLE,
_ICEVISION_AVAILABLE,
_TORCHVISION_AVAILABLE,
lazy_import,
requires,
)
from flash.image.data import ImagePathsDataSource
from flash.image.detection.transforms import default_transforms

if _COCO_AVAILABLE:
from pycocotools.coco import COCO
pass

SampleCollection = None
if _FIFTYONE_AVAILABLE:
Expand All @@ -42,75 +43,102 @@
if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader

if _ICEVISION_AVAILABLE:
from icevision.core import BaseRecord, ClassMapRecordComponent, ImageRecordComponent, tasks
from icevision.data import SingleSplitSplitter
from icevision.parsers import Parser

class COCODataSource(DataSource[Tuple[str, str]]):

@requires("pycocotools")
def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
root, ann_file = data

coco = COCO(ann_file)

categories = coco.loadCats(coco.getCatIds())
if categories:
dataset.num_classes = categories[-1]["id"] + 1

img_ids = list(sorted(coco.imgs.keys()))
paths = coco.loadImgs(img_ids)

data = []

for img_id, path in zip(img_ids, paths):
path = path["file_name"]

ann_ids = coco.getAnnIds(imgIds=img_id)
annotations = coco.loadAnns(ann_ids)
class IceVisionPathsDataSource(ImagePathsDataSource):

boxes, labels, areas, iscrowd = [], [], [], []
def __init__(self, parser: Type[Parser]):
    """Create a data source that reads annotations through an IceVision parser.

    Args:
        parser: The IceVision ``Parser`` class (not an instance). It is stored
            on ``self.parser`` and later called as ``self.parser(ann_file, root)``
            when data is loaded.
    """
    # Run the base initialiser so state set up by ``ImagePathsDataSource`` is
    # present on this instance; the base class is constructed without arguments
    # elsewhere in this module, so a no-argument call is valid.
    super().__init__()
    self.parser = parser

# Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py
if self.training and all(any(o <= 1 for o in obj["bbox"][2:]) for obj in annotations):
continue

for obj in annotations:
xmin = obj["bbox"][0]
ymin = obj["bbox"][1]
xmax = xmin + obj["bbox"][2]
ymax = ymin + obj["bbox"][3]
def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
root, ann_file = data

bbox = [xmin, ymin, xmax, ymax]
keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0])
if keep:
boxes.append(bbox)
labels.append(obj["category_id"])
areas.append(obj["area"])
iscrowd.append(obj["iscrowd"])

data.append(
dict(
input=os.path.join(root, path),
target=dict(
boxes=boxes,
labels=labels,
image_id=img_id,
area=areas,
iscrowd=iscrowd,
)
)
)
return data
parser = self.parser(ann_file, root)
dataset.num_classes = len(parser.class_map)
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DefaultDataKeys.INPUT: record} for record in records[0]]

def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
    """Load prediction data by deferring, unchanged, to the parent class.

    Args:
        data: Forwarded as-is to the parent's ``predict_load_data``.
        dataset: Forwarded as-is to the parent's ``predict_load_data``.

    Returns:
        Whatever the parent implementation returns.
    """
    parent_loader = super().predict_load_data
    return parent_loader(data, dataset)

# coco = COCO(ann_file)
#
# categories = coco.loadCats(coco.getCatIds())
# if categories:
# dataset.num_classes = categories[-1]["id"] + 1
#
# img_ids = list(sorted(coco.imgs.keys()))
# paths = coco.loadImgs(img_ids)
#
# data = []
#
# for img_id, path in zip(img_ids, paths):
# path = path["file_name"]
#
# ann_ids = coco.getAnnIds(imgIds=img_id)
# annotations = coco.loadAnns(ann_ids)
#
# boxes, labels, areas, iscrowd = [], [], [], []
#
# # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py
# if self.training and all(any(o <= 1 for o in obj["bbox"][2:]) for obj in annotations):
# continue
#
# for obj in annotations:
# xmin = obj["bbox"][0]
# ymin = obj["bbox"][1]
# xmax = xmin + obj["bbox"][2]
# ymax = ymin + obj["bbox"][3]
#
# bbox = [xmin, ymin, xmax, ymax]
# keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0])
# if keep:
# boxes.append(bbox)
# labels.append(obj["category_id"])
# areas.append(obj["area"])
# iscrowd.append(obj["iscrowd"])
#
# data.append(
# dict(
# input=os.path.join(root, path),
# target=dict(
# boxes=boxes,
# labels=labels,
# image_id=img_id,
# area=areas,
# iscrowd=iscrowd,
# )
# )
# )
# return data

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
filepath = sample[DefaultDataKeys.INPUT]
img = default_loader(filepath)
sample[DefaultDataKeys.INPUT] = img
w, h = img.size # WxH
sample[DefaultDataKeys.METADATA] = {
"filepath": filepath,
"size": (h, w),
}
return sample
return sample
# TODO: get image size for metadata
# sample[DefaultDataKeys.INPUT] = sample[DefaultDataKeys.INPUT].load()
return sample[DefaultDataKeys.INPUT].load()
# filepath = sample[DefaultDataKeys.INPUT]
# img = default_loader(filepath)
# sample[DefaultDataKeys.INPUT] = img
# w, h = img.size # WxH
# sample[DefaultDataKeys.METADATA] = {
# "filepath": filepath,
# "size": (h, w),
# }
# return sample
# return sample

def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a prediction image in an IceVision record.

    The sample is first loaded by the parent class's ``load_sample``. Its input
    is then converted to a NumPy array and attached to a new ``BaseRecord``,
    which also receives a detection ``ClassMapRecordComponent``.

    Args:
        sample: The sample to load; passed to the parent's ``load_sample``.

    Returns:
        The newly built record (not a sample dictionary).
        NOTE(review): the ``Dict[str, Any]`` return annotation does not match
        the returned record -- confirm what callers expect.
    """
    loaded = super().load_sample(sample)
    pixels = np.array(loaded[DefaultDataKeys.INPUT])

    # NOTE: no record id is assigned here (a ``set_record_id`` call was left disabled).
    record = BaseRecord([ImageRecordComponent()])
    record.set_img(pixels)
    record.add_component(ClassMapRecordComponent(task=tasks.detection))
    return record


class ObjectDetectionFiftyOneDataSource(FiftyOneDataSource):
Expand Down Expand Up @@ -205,22 +233,27 @@ def __init__(
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
image_size: Tuple[int, int] = (128, 128),
**data_source_kwargs: Any,
):
self.image_size = image_size

super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs),
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
"coco": COCODataSource(),
# DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs),
# DefaultDataSources.FILES: ObjectDetectionPathsDataSource(),
# DefaultDataSources.FOLDERS: ObjectDetectionPathsDataSource(),
# "coco": COCODataSource(),
},
default_data_source=DefaultDataSources.FILES,
default_data_source="coco",
)

self._default_collate = self._identity

def get_state_dict(self) -> Dict[str, Any]:
    """Return a new dictionary holding the entries of ``self.transforms``.

    The copy is shallow: the returned dict is a distinct object, but its values
    are the same objects stored in ``self.transforms``.
    """
    return dict(self.transforms)

Expand All @@ -229,15 +262,17 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
return cls(**state_dict)

def default_transforms(self) -> Optional[Dict[str, Callable]]:
return default_transforms()
return default_transforms(self.image_size)

def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
    """Return the default transforms for training, built for ``self.image_size``.

    Calls the module-level ``default_transforms`` helper with this instance's
    ``image_size`` and returns its result unchanged.
    """
    size = self.image_size
    return default_transforms(size)


class ObjectDetectionData(DataModule):

preprocess_cls = ObjectDetectionPreprocess

@classmethod
@requires("pycocotools")
def from_coco(
cls,
train_folder: Optional[str] = None,
Expand All @@ -246,9 +281,11 @@ def from_coco(
val_ann_file: Optional[str] = None,
test_folder: Optional[str] = None,
test_ann_file: Optional[str] = None,
predict_folder: Optional[str] = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
Expand Down Expand Up @@ -298,6 +335,7 @@ def from_coco(
(train_folder, train_ann_file) if train_folder else None,
(val_folder, val_ann_file) if val_folder else None,
(test_folder, test_ann_file) if test_folder else None,
predict_folder,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
Expand Down
Loading

0 comments on commit 5802dcf

Please sign in to comment.