Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

updated mock Image object #670

Merged
merged 10 commits into from
Aug 17, 2021
16 changes: 14 additions & 2 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import warn
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
Expand Down Expand Up @@ -44,8 +45,19 @@
from PIL import Image
else:

class Image:
Image = None
class MetaImage(type):
def __init__(cls, name, bases, dct):
super().__init__(name, bases, dct)

cls._Image = None

@property
def Image(cls):
warn("Mock object called due to missing PIL library. Install PIL using 'pip install Pillow'.")
return cls._Image

class Image(metaclass=MetaImage):
pass


class ImageClassificationDataFrameDataSource(LoaderDataFrameDataSource):
Expand Down
22 changes: 17 additions & 5 deletions flash/image/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import base64
from io import BytesIO
from logging import warn
from pathlib import Path
from typing import Any, Dict, Optional

Expand All @@ -39,11 +40,22 @@
IMG_EXTENSIONS = ()

if _PIL_AVAILABLE:
from PIL import Image as PILImage
from PIL import Image
else:

class Image:
Image = None
class MetaImage(type):
def __init__(cls, name, bases, dct):
super().__init__(name, bases, dct)

cls._Image = None

@property
def Image(cls):
warn("Mock object called due to missing PIL library. Install PIL using 'pip install Pillow'.")
return cls._Image

class Image(metaclass=MetaImage):
pass


NP_EXTENSIONS = (".npy", ".npz")
Expand All @@ -53,7 +65,7 @@ def image_loader(filepath: str):
if has_file_allowed_extension(filepath, IMG_EXTENSIONS):
img = default_loader(filepath)
elif has_file_allowed_extension(filepath, NP_EXTENSIONS):
img = PILImage.fromarray(np.load(filepath).astype("uint8"), "RGB")
img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB")
else:
raise ValueError(
f"File: {filepath} has an unsupported extension. Supported extensions: "
Expand All @@ -72,7 +84,7 @@ def deserialize(self, data: str) -> Dict:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
buffer = BytesIO(img)
img = PILImage.open(buffer, mode="r")
img = Image.open(buffer, mode="r")
return {
DefaultDataKeys.INPUT: img,
}
Expand Down
16 changes: 14 additions & 2 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from logging import warn
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np
Expand Down Expand Up @@ -71,8 +72,19 @@
from PIL import Image
else:

class Image:
Image = None
class MetaImage(type):
def __init__(cls, name, bases, dct):
super().__init__(name, bases, dct)

cls._Image = None

@property
def Image(cls):
warn("Mock object called due to missing PIL library. Install PIL using 'pip install Pillow'.")
return cls._Image

class Image(metaclass=MetaImage):
pass


class SemanticSegmentationNumpyDataSource(NumpyDataSource):
Expand Down
16 changes: 14 additions & 2 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from logging import warn
from numbers import Number
from pathlib import Path
from typing import Any, Tuple
Expand Down Expand Up @@ -44,8 +45,19 @@
from PIL import Image
else:

class Image:
Image = None
class MetaImage(type):
def __init__(cls, name, bases, dct):
super().__init__(name, bases, dct)

cls._Image = None

@property
def Image(cls):
warn("Mock object called due to missing PIL library. Install PIL using 'pip install Pillow'.")
return cls._Image

class Image(metaclass=MetaImage):
pass


# ======== Mock functions ========
Expand Down