diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 0c48ff6014..1a7be19e05 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -17,6 +17,7 @@ import types from importlib.util import find_spec from typing import Callable, List, Union +from warnings import warn from pkg_resources import DistributionNotFound @@ -99,6 +100,25 @@ def _compare_version(package: str, op, version) -> bool: _ICEDATA_AVAILABLE = _module_available("icedata") _TORCH_ORT_AVAILABLE = _module_available("torch_ort") +if _PIL_AVAILABLE: + from PIL import Image +else: + + class MetaImage(type): + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + + cls._Image = None + + @property + def Image(cls): + warn("Mock object called due to missing PIL library. Please use \"pip install 'lightning-flash[image]'\".") + return cls._Image + + class Image(metaclass=MetaImage): + pass + + if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 19215b02e6..32ed049ce6 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -24,7 +24,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LoaderDataFrameDataSource from flash.core.data.process import Deserializer, Preprocess -from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, requires, requires_extras +from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires, requires_extras from flash.image.classification.transforms import default_transforms, train_default_transforms from flash.image.data import ( image_loader, @@ -40,13 +40,6 @@ else: plt = None -if _PIL_AVAILABLE: - from PIL import Image -else: - - class Image: - Image = None - class ImageClassificationDataFrameDataSource(LoaderDataFrameDataSource): @requires_extras("image") diff --git a/flash/image/data.py b/flash/image/data.py index b2ea2e3fa1..45d7f2af6c 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -29,7 +29,7 @@ TensorDataSource, ) from flash.core.data.process import Deserializer -from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, requires_extras +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires_extras if _TORCHVISION_AVAILABLE: import torchvision @@ -38,13 +38,6 @@ else: IMG_EXTENSIONS = () -if _PIL_AVAILABLE: - from PIL import Image as PILImage -else: - - class Image: - Image = None - NP_EXTENSIONS = (".npy", ".npz") @@ -53,7 +46,7 @@ def image_loader(filepath: str): if has_file_allowed_extension(filepath, IMG_EXTENSIONS): img = default_loader(filepath) elif has_file_allowed_extension(filepath, NP_EXTENSIONS): - img = PILImage.fromarray(np.load(filepath).astype("uint8"), "RGB") + img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB") else: raise ValueError( f"File: {filepath} has an unsupported extension. Supported extensions: " @@ -72,7 +65,7 @@ def deserialize(self, data: str) -> Dict: encoded_with_padding = (data + "===").encode("ascii") img = base64.b64decode(encoded_with_padding) buffer = BytesIO(img) - img = PILImage.open(buffer, mode="r") + img = Image.open(buffer, mode="r") return { DefaultDataKeys.INPUT: img, } diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 8ee8382002..6b39ee1450 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -38,8 +38,8 @@ from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _MATPLOTLIB_AVAILABLE, - _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, + Image, lazy_import, requires, requires_extras, @@ -68,13 +68,6 @@ else: IMG_EXTENSIONS = None -if _PIL_AVAILABLE: - from PIL import Image -else: - - class Image: - Image = None - class SemanticSegmentationNumpyDataSource(NumpyDataSource): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: diff --git a/tests/core/test_model.py b/tests/core/test_model.py index e16d62e686..3d3b53b111 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -32,7 +32,7 @@ from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask from flash.core.data.process import DefaultPreprocess, Postprocess -from flash.core.utilities.imports import _PIL_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE +from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TEXT_AVAILABLE, Image from flash.image import ImageClassificationData, ImageClassifier from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING @@ -41,13 +41,6 @@ else: TabularClassifier = None -if _PIL_AVAILABLE: - from PIL import Image -else: - - class Image: - Image = None - # ======== Mock functions ========