Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'feature/resnet_in_chans' of https://github.com/PyTorchL…
Browse files Browse the repository at this point in the history
…ightning/lightning-flash into feature/resnet_in_chans
  • Loading branch information
ethanwharris committed Aug 17, 2021
2 parents fb9d66e + 54927a4 commit c7b5c2f
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 34 deletions.
20 changes: 20 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import types
from importlib.util import find_spec
from typing import Callable, List, Union
from warnings import warn

from pkg_resources import DistributionNotFound

Expand Down Expand Up @@ -99,6 +100,25 @@ def _compare_version(package: str, op, version) -> bool:
_ICEDATA_AVAILABLE = _module_available("icedata")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")

if _PIL_AVAILABLE:
from PIL import Image
else:

class MetaImage(type):
def __init__(cls, name, bases, dct):
super().__init__(name, bases, dct)

cls._Image = None

@property
def Image(cls):
warn("Mock object called due to missing PIL library. Please use \"pip install 'lightning-flash[image]'\".")
return cls._Image

class Image(metaclass=MetaImage):
pass


if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")

Expand Down
9 changes: 1 addition & 8 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LoaderDataFrameDataSource
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, requires, requires_extras
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires, requires_extras
from flash.image.classification.transforms import default_transforms, train_default_transforms
from flash.image.data import (
image_loader,
Expand All @@ -40,13 +40,6 @@
else:
plt = None

if _PIL_AVAILABLE:
from PIL import Image
else:

class Image:
Image = None


class ImageClassificationDataFrameDataSource(LoaderDataFrameDataSource):
@requires_extras("image")
Expand Down
13 changes: 3 additions & 10 deletions flash/image/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
TensorDataSource,
)
from flash.core.data.process import Deserializer
from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, requires_extras
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires_extras

if _TORCHVISION_AVAILABLE:
import torchvision
Expand All @@ -38,13 +38,6 @@
else:
IMG_EXTENSIONS = ()

if _PIL_AVAILABLE:
from PIL import Image as PILImage
else:

class Image:
Image = None


NP_EXTENSIONS = (".npy", ".npz")

Expand All @@ -53,7 +46,7 @@ def image_loader(filepath: str):
if has_file_allowed_extension(filepath, IMG_EXTENSIONS):
img = default_loader(filepath)
elif has_file_allowed_extension(filepath, NP_EXTENSIONS):
img = PILImage.fromarray(np.load(filepath).astype("uint8"), "RGB")
img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB")
else:
raise ValueError(
f"File: {filepath} has an unsupported extension. Supported extensions: "
Expand All @@ -72,7 +65,7 @@ def deserialize(self, data: str) -> Dict:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
buffer = BytesIO(img)
img = PILImage.open(buffer, mode="r")
img = Image.open(buffer, mode="r")
return {
DefaultDataKeys.INPUT: img,
}
Expand Down
9 changes: 1 addition & 8 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
from flash.core.utilities.imports import (
_FIFTYONE_AVAILABLE,
_MATPLOTLIB_AVAILABLE,
_PIL_AVAILABLE,
_TORCHVISION_AVAILABLE,
Image,
lazy_import,
requires,
requires_extras,
Expand Down Expand Up @@ -68,13 +68,6 @@
else:
IMG_EXTENSIONS = None

if _PIL_AVAILABLE:
from PIL import Image
else:

class Image:
Image = None


class SemanticSegmentationNumpyDataSource(NumpyDataSource):
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
Expand Down
9 changes: 1 addition & 8 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from flash.core.adapter import Adapter
from flash.core.classification import ClassificationTask
from flash.core.data.process import DefaultPreprocess, Postprocess
from flash.core.utilities.imports import _PIL_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE
from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TEXT_AVAILABLE, Image
from flash.image import ImageClassificationData, ImageClassifier
from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING

Expand All @@ -41,13 +41,6 @@
else:
TabularClassifier = None

if _PIL_AVAILABLE:
from PIL import Image
else:

class Image:
Image = None


# ======== Mock functions ========

Expand Down

0 comments on commit c7b5c2f

Please sign in to comment.