Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'bugfix/audio_numpy_loading' of https://github.com/PyTor…
Browse files Browse the repository at this point in the history
…chLightning/lightning-flash into bugfix/audio_numpy_loading
  • Loading branch information
ethanwharris committed Sep 7, 2021
2 parents c3e4f4e + 2cc719a commit 9fbfab5
Show file tree
Hide file tree
Showing 27 changed files with 95 additions and 88 deletions.
6 changes: 3 additions & 3 deletions flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from flash.core.data.process import Deserializer, Postprocess, Preprocess
from flash.core.data.properties import ProcessState
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires_extras
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires

if _AUDIO_AVAILABLE:
import librosa
Expand Down Expand Up @@ -155,7 +155,7 @@ def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:


class SpeechRecognitionPreprocess(Preprocess):
@requires_extras("audio")
@requires("audio")
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
Expand Down Expand Up @@ -197,7 +197,7 @@ class SpeechRecognitionBackboneState(ProcessState):


class SpeechRecognitionPostprocess(Postprocess):
@requires_extras("audio")
@requires("audio")
def __init__(self):
super().__init__()

Expand Down
6 changes: 3 additions & 3 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import nn

from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires_extras
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires

if _ICEVISION_AVAILABLE:
from icevision.core import tasks
Expand Down Expand Up @@ -206,15 +206,15 @@ def forward(self, x):
return from_icevision_record(record)


@requires_extras("image")
@requires(["image", "icevision"])
def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
"""The default transforms from IceVision."""
return {
"pre_tensor_transform": IceVisionTransformAdapter([*A.resize_and_pad(image_size), A.Normalize()]),
}


@requires_extras("image")
@requires(["image", "icevision"])
def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
"""The default augmentations from IceVision."""
return {
Expand Down
12 changes: 6 additions & 6 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.serve import Composition
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import requires_extras
from flash.core.utilities.imports import requires


class ModuleWrapperBase:
Expand Down Expand Up @@ -258,11 +258,11 @@ class CheckDependenciesMeta(ABCMeta):
def __new__(mcs, *args, **kwargs):
result = ABCMeta.__new__(mcs, *args, **kwargs)
if result.required_extras is not None:
result.__init__ = requires_extras(result.required_extras)(result.__init__)
result.__init__ = requires(result.required_extras)(result.__init__)
load_from_checkpoint = getattr(result, "load_from_checkpoint", None)
if load_from_checkpoint is not None:
result.load_from_checkpoint = classmethod(
requires_extras(result.required_extras)(result.load_from_checkpoint.__func__)
requires(result.required_extras)(result.load_from_checkpoint.__func__)
)
return result

Expand All @@ -282,7 +282,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check

schedulers: FlashRegistry = _SCHEDULERS_REGISTRY

required_extras: Optional[str] = None
required_extras: Optional[Union[str, List[str]]] = None

def __init__(
self,
Expand Down Expand Up @@ -826,7 +826,7 @@ def configure_callbacks(self):
if flash._IS_TESTING and torch.cuda.is_available():
return [BenchmarkConvergenceCI()]

@requires_extras("serve")
@requires("serve")
def run_serve_sanity_check(self):
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")
Expand All @@ -846,7 +846,7 @@ def run_serve_sanity_check(self):
resp = tc.post("http://0.0.0.0:8000/predict", json=body)
print(f"Sanity check response: {resp.json()}")

@requires_extras("serve")
@requires("serve")
def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> "Composition":
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")
Expand Down
4 changes: 2 additions & 2 deletions flash/core/serve/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from flash.core.serve.core import ParameterContainer, Servable
from flash.core.serve.decorators import BoundMeta, UnboundMeta
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_AVAILABLE, requires_extras
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_AVAILABLE, requires

if _CYTOOLZ_AVAILABLE:
from cytoolz import first, isiterable, valfilter
Expand Down Expand Up @@ -145,7 +145,7 @@ def _validate_config_args(config: Optional[Dict[str, Union[str, int, float, byte
class FlashServeMeta(type):
"""We keep a mapping of externally used names to classes."""

@requires_extras("serve")
@requires("serve")
def __new__(cls, name, bases, namespace):
# create new instance of cls in order to apply any @expose class decorations.
_tmp_cls = super().__new__(cls, name, bases, namespace)
Expand Down
4 changes: 2 additions & 2 deletions flash/core/serve/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from flash.core.serve.types.base import BaseType
from flash.core.serve.utils import download_file
from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, requires_extras
from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, requires

if _PYDANTIC_AVAILABLE:
from pydantic import FilePath, HttpUrl, parse_obj_as, ValidationError
Expand Down Expand Up @@ -102,7 +102,7 @@ class Servable:
* How to handle ``__init__`` args not recorded in hparams of ``pl.LightningModule``
"""

@requires_extras("serve")
@requires("serve")
def __init__(
self,
*args: ServableValidArgs_T,
Expand Down
43 changes: 21 additions & 22 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import operator
import types
from importlib.util import find_spec
from typing import Callable, List, Union
from typing import List, Union
from warnings import warn

from pkg_resources import DistributionNotFound
Expand Down Expand Up @@ -142,8 +142,6 @@ class Image(metaclass=MetaImage):
_KORNIA_AVAILABLE,
_PYSTICHE_AVAILABLE,
_SEGMENTATION_MODELS_AVAILABLE,
_ICEVISION_AVAILABLE,
_ICEDATA_AVAILABLE,
]
)
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
Expand All @@ -163,22 +161,33 @@ class Image(metaclass=MetaImage):
}


def _requires(
module_paths: Union[str, List],
module_available: Callable[[str], bool],
formatter: Callable[[List[str]], str],
):
def requires(module_paths: Union[str, List]):

if not isinstance(module_paths, list):
module_paths = [module_paths]

def decorator(func):
if not all(module_available(module_path) for module_path in module_paths):
available = True
extras = []
modules = []
for module_path in module_paths:
if module_path in _EXTRAS_AVAILABLE:
extras.append(module_path)
if not _EXTRAS_AVAILABLE[module_path]:
available = False
else:
modules.append(module_path)
if not _module_available(module_path):
available = False

if not available:
modules = [f"'{module}'" for module in modules]
modules.append(f"'lightning-flash[{','.join(extras)}]'")

@functools.wraps(func)
def wrapper(*args, **kwargs):
raise ModuleNotFoundError(
f"Required dependencies not available. Please run: pip install {formatter(module_paths)}"
f"Required dependencies not available. Please run: pip install {' '.join(modules)}"
)

return wrapper
Expand All @@ -188,18 +197,8 @@ def wrapper(*args, **kwargs):
return decorator


def requires(module_paths: Union[str, List]):
return _requires(module_paths, _module_available, lambda module_paths: " ".join(module_paths))


def requires_extras(extras: Union[str, List]):
return _requires(
extras, lambda extras: _EXTRAS_AVAILABLE[extras], lambda extras: f"'lightning-flash[{','.join(extras)}]'"
)


def example_requires(extras: Union[str, List[str]]):
return requires_extras(extras)(lambda: None)()
def example_requires(module_paths: Union[str, List[str]]):
return requires(module_paths)(lambda: None)()


def lazy_import(module_name, callback=None):
Expand Down
4 changes: 2 additions & 2 deletions flash/graph/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
from torch.utils.data import Dataset

from flash.core.data.data_source import DatasetDataSource
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires_extras
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires

if _GRAPH_AVAILABLE:
from torch_geometric.data import Data
from torch_geometric.data import Dataset as TorchGeometricDataset


class GraphDatasetDataSource(DatasetDataSource):
@requires_extras("graph")
@requires("graph")
def load_data(self, data: Dataset, dataset: Any = None) -> Dataset:
data = super().load_data(data, dataset=dataset)
if not self.predicting:
Expand Down
6 changes: 3 additions & 3 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LoaderDataFrameDataSource
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires, requires_extras
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires
from flash.image.classification.transforms import default_transforms, train_default_transforms
from flash.image.data import (
image_loader,
Expand All @@ -45,7 +45,7 @@ class ImageClassificationDataFrameDataSource(LoaderDataFrameDataSource):
def __init__(self):
super().__init__(image_loader)

@requires_extras("image")
@requires("image")
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample = super().load_sample(sample, dataset)
w, h = sample[DefaultDataKeys.INPUT].size # WxH
Expand Down Expand Up @@ -315,7 +315,7 @@ class MatplotlibVisualization(BaseVisualization):
block_viz_window: bool = True # parameter to allow user to block visualisation windows

@staticmethod
@requires_extras("image")
@requires("image")
def _to_numpy(img: Union[np.ndarray, torch.Tensor, Image.Image]) -> np.ndarray:
out: np.ndarray
if isinstance(img, np.ndarray):
Expand Down
6 changes: 3 additions & 3 deletions flash/image/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
TensorDataSource,
)
from flash.core.data.process import Deserializer
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires_extras
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
Expand All @@ -55,7 +55,7 @@ def image_loader(filepath: str):


class ImageDeserializer(Deserializer):
@requires_extras("image")
@requires("image")
def deserialize(self, data: str) -> Dict:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
Expand All @@ -75,7 +75,7 @@ class ImagePathsDataSource(PathsDataSource):
def __init__(self):
super().__init__(loader=image_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS)

@requires_extras("image")
@requires("image")
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample = super().load_sample(sample, dataset)
w, h = sample[DefaultDataKeys.INPUT].size # WxH
Expand Down
2 changes: 1 addition & 1 deletion flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ObjectDetector(AdapterTask):

heads: FlashRegistry = OBJECT_DETECTION_HEADS

required_extras: str = "image"
required_extras: List[str] = ["image", "icevision", "effdet"]

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions flash/image/instance_segmentation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Callable, Optional

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires
from flash.image import InstanceSegmentation, InstanceSegmentationData

if _ICEDATA_AVAILABLE:
Expand All @@ -24,7 +24,7 @@
__all__ = ["instance_segmentation"]


@requires_extras("image")
@requires(["image", "icedata"])
def from_pets(
val_split: float = 0.1,
batch_size: int = 4,
Expand Down
2 changes: 1 addition & 1 deletion flash/image/instance_segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class InstanceSegmentation(AdapterTask):

heads: FlashRegistry = INSTANCE_SEGMENTATION_HEADS

required_extras: str = "image"
required_extras: List[str] = ["image", "icevision"]

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions flash/image/keypoint_detection/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Callable, Optional

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires
from flash.image import KeypointDetectionData, KeypointDetector

if _ICEDATA_AVAILABLE:
Expand All @@ -23,7 +23,7 @@
__all__ = ["keypoint_detection"]


@requires_extras("image")
@requires("image")
def from_biwi(
val_split: float = 0.1,
batch_size: int = 4,
Expand Down
2 changes: 1 addition & 1 deletion flash/image/keypoint_detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class KeypointDetector(AdapterTask):

heads: FlashRegistry = KEYPOINT_DETECTION_HEADS

required_extras: str = "image"
required_extras: List[str] = ["image", "icevision"]

def __init__(
self,
Expand Down
3 changes: 1 addition & 2 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
Image,
lazy_import,
requires,
requires_extras,
)
from flash.image.data import ImageDeserializer, IMG_EXTENSIONS
from flash.image.segmentation.serialization import SegmentationLabels
Expand Down Expand Up @@ -459,7 +458,7 @@ def __init__(self, labels_map: Dict[int, Tuple[int, int, int]]):
self.labels_map: Dict[int, Tuple[int, int, int]] = labels_map

@staticmethod
@requires_extras("image")
@requires("image")
def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray:
out: np.ndarray
if isinstance(img, Image.Image):
Expand Down
3 changes: 1 addition & 2 deletions flash/image/segmentation/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
_MATPLOTLIB_AVAILABLE,
lazy_import,
requires,
requires_extras,
)

Segmentation = None
Expand Down Expand Up @@ -56,7 +55,7 @@ class SegmentationLabels(Serializer):
visualize: Wether to visualize the image labels.
"""

@requires_extras("image")
@requires("image")
def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False):
super().__init__()
self.labels_map = labels_map
Expand Down
Loading

0 comments on commit 9fbfab5

Please sign in to comment.