Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Serve sanity checks #423

Merged
merged 22 commits into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ jobs:
pip list
shell: bash

- name: Install serve test dependencies
if: matrix.topic == 'serve'
run: |
pip install '.[all]' --pre --upgrade
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

- name: Cache datasets
uses: actions/cache@v2
with:
Expand All @@ -115,7 +120,8 @@ jobs:

- name: Tests
env:
FIFTYONE_DO_NOT_TRACK: true
FLASH_TEST_TOPIC: ${{ matrix.topic }}
FIFTYONE_DO_NOT_TRACK: true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it rather be just 0/1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it can be. This was done by the fiftyone people for something on their end. I wouldn't want to change it in case something breaks.

run: |
# tox --sitepackages
coverage run --source flash -m pytest flash tests -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
Expand Down Expand Up @@ -143,3 +149,8 @@ jobs:
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false

- name: Uninstall
run: |
pip uninstall lightning-flash -y
shell: bash
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))
- Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390))
- Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401))
- Added a sanity checking feature to flash.serve ([#423](https://github.com/PyTorchLightning/lightning-flash/pull/423))

### Changed

Expand Down
1 change: 1 addition & 0 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from flash.core.trainer import Trainer # noqa: E402

_PACKAGE_ROOT = os.path.dirname(__file__)
ASSETS_ROOT = os.path.join(_PACKAGE_ROOT, "assets")
PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1"

Expand Down
Binary file added flash/assets/fish.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 20 additions & 17 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import torch
import torch.nn.functional as F
Expand All @@ -24,7 +24,6 @@
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE

if _FIFTYONE_AVAILABLE:
import fiftyone as fo
from fiftyone.core.labels import Classification, Classifications
else:
Classification, Classifications = None, None
Expand Down Expand Up @@ -83,34 +82,43 @@ def multi_label(self) -> bool:
return self._mutli_label


class Logits(ClassificationSerializer):
class PredsClassificationSerializer(ClassificationSerializer):
    """A :class:`~flash.core.classification.ClassificationSerializer` that pulls the
    :attr:`~flash.core.data.data_source.DefaultDataKeys.PREDS` entry out of the sample
    and returns it as a :class:`torch.Tensor`.
    """

    def serialize(self, sample: Any) -> Any:
        """Return the predictions held by ``sample`` as a tensor.

        Args:
            sample: Either a mapping containing ``DefaultDataKeys.PREDS`` or the raw predictions themselves.

        Returns:
            The predictions unchanged if they already are a ``torch.Tensor``, otherwise the result of
            ``torch.tensor`` applied to them.
        """
        # Only index into the sample when it is a mapping that actually carries predictions.
        has_preds = isinstance(sample, Mapping) and DefaultDataKeys.PREDS in sample
        preds = sample[DefaultDataKeys.PREDS] if has_preds else sample
        if isinstance(preds, torch.Tensor):
            return preds
        return torch.tensor(preds)


class Logits(PredsClassificationSerializer):
"""A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
return sample.tolist()
return super().serialize(sample).tolist()


class Probabilities(ClassificationSerializer):
class Probabilities(PredsClassificationSerializer):
"""A :class:`.Serializer` which applies a softmax to the model outputs (assumed to be logits) and converts to a
list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
sample = super().serialize(sample)
if self.multi_label:
return torch.sigmoid(sample).tolist()
return torch.softmax(sample, -1).tolist()


class Classes(ClassificationSerializer):
class Classes(PredsClassificationSerializer):
"""A :class:`.Serializer` which applies an argmax to the model outputs (either logits or probabilities) and
converts to a list.

Args:
multi_label: If true, treats outputs as multi label logits.

threshold: The threshold to use for multi_label classification.
"""

Expand All @@ -120,8 +128,7 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5):
self.threshold = threshold

def serialize(self, sample: Any) -> Union[int, List[int]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
sample = super().serialize(sample)
if self.multi_label:
one_hot = (sample.sigmoid() > self.threshold).int().tolist()
result = []
Expand All @@ -139,9 +146,7 @@ class Labels(Classes):
Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.LabelsState`.

multi_label: If true, treats outputs as multi label logits.

threshold: The threshold to use for multi_label classification.
"""

Expand All @@ -153,8 +158,6 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
labels = None

if self._labels is not None:
Expand Down
4 changes: 4 additions & 0 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) ->
data_pipeline_state._initialized = True # TODO: Not sure we need this
return data_pipeline_state

@property
def example_input(self) -> str:
    """Return the example input exposed by this pipeline's deserializer."""
    deserializer = self._deserializer
    return deserializer.example_input

@staticmethod
def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool:
"""
Expand Down
7 changes: 6 additions & 1 deletion flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import os
from abc import ABC, abstractclassmethod, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence

import torch
from pytorch_lightning.trainer.states import RunningStage
Expand Down Expand Up @@ -569,6 +569,11 @@ class Deserializer(Properties):
def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor???
raise NotImplementedError

@property
@abstractmethod
def example_input(self) -> str:
    """Return an example input string for this deserializer.

    Declared abstract: concrete deserializers are expected to override it. The base
    implementation has no body of its own and therefore returns ``None``.
    """

def __call__(self, sample: Any) -> Any:
return self.deserialize(sample)

Expand Down
19 changes: 17 additions & 2 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import functools
import inspect
import os
from copy import deepcopy
from importlib import import_module
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
Expand Down Expand Up @@ -592,7 +593,9 @@ def configure_callbacks(self):
if flash._IS_TESTING and torch.cuda.is_available():
return [BenchmarkConvergenceCI()]

def serve(self, host: str = "127.0.0.1", port: int = 8000) -> 'Composition':
def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> 'Composition':
from fastapi.testclient import TestClient

from flash.core.serve.flash_components import FlashInputs, FlashOutputs

class FlashServeModelComponent(ModelComponent):
Expand Down Expand Up @@ -626,7 +629,19 @@ def predict(self, inputs):
preds = self.postprocessor(preds)
return preds

if sanity_check:
print("Running sanity check")
comp = FlashServeModelComponent(self)
composition = Composition(predict=comp, TESTING=True, DEBUG=True)
app = composition.serve(host="0.0.0.0", port=8000)

with TestClient(app) as tc:
# Use the public ``DataPipeline.example_input`` property rather than reaching into the private deserializer.
input_str = self.data_pipeline.example_input
body = {"session": "UUID", "payload": {"inputs": {"data": input_str}}}
resp = tc.post("http://0.0.0.0:8000/predict", json=body)
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
print(f"Sanity check response: {resp.json()}")

comp = FlashServeModelComponent(self)
composition = Composition(predict=comp)
# Default to "0" when FLASH_TESTING is unset; indexing os.environ directly would raise KeyError
# for every user who serves a model outside the test suite.
composition = Composition(predict=comp, TESTING=os.getenv("FLASH_TESTING", "0") == "1")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, some variables are 0/1, others you write as true... let's be consistent.

ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
composition.serve(host=host, port=port)
return composition
1 change: 0 additions & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""

import importlib
import operator
from importlib.util import find_spec
Expand Down
34 changes: 9 additions & 25 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
Expand All @@ -25,43 +23,29 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE
from flash.image.classification.transforms import default_transforms, train_default_transforms
from flash.image.data import ImageFiftyOneDataSource, ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource
from flash.image.data import (
ImageDeserializer,
ImageFiftyOneDataSource,
ImageNumpyDataSource,
ImagePathsDataSource,
ImageTensorDataSource,
)

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
plt = None

if _TORCHVISION_AVAILABLE:
import torchvision

if _IMAGE_AVAILABLE:
from PIL import Image
from PIL import Image as PILImage
else:

class Image:
Image = None


class ImageClassificationDeserializer(Deserializer):

def __init__(self):

self.to_tensor = torchvision.transforms.ToTensor()

def deserialize(self, data: str) -> Dict:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
buffer = BytesIO(img)
img = PILImage.open(buffer, mode="r")
return {
DefaultDataKeys.INPUT: img,
}


class ImageClassificationPreprocess(Preprocess):

def __init__(
Expand All @@ -88,7 +72,7 @@ def __init__(
DefaultDataSources.NUMPY: ImageNumpyDataSource(),
DefaultDataSources.TENSORS: ImageTensorDataSource(),
},
deserializer=deserializer or ImageClassificationDeserializer(),
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
)

Expand Down
36 changes: 35 additions & 1 deletion flash/image/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Optional

import torch

import flash
from flash.core.data.data_source import (
DefaultDataKeys,
FiftyOneDataSource,
NumpyDataSource,
PathsDataSource,
TensorDataSource,
)
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.core.data.process import Deserializer
from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
from torchvision.transforms.functional import to_pil_image
else:
IMG_EXTENSIONS = []

if _IMAGE_AVAILABLE:
from PIL import Image as PILImage
else:

class Image:
Image = None


class ImageDeserializer(Deserializer):
    """Deserializer that turns a base64-encoded image string into a PIL image sample."""

    def __init__(self):
        super().__init__()
        # Kept as an attribute for subclasses that need a tensor instead of a PIL image.
        self.to_tensor = torchvision.transforms.ToTensor()

    def deserialize(self, data: str) -> Dict:
        """Decode ``data`` and open it as an image.

        Args:
            data: The base64 text of an encoded image.

        Returns:
            A dict holding the opened PIL image under ``DefaultDataKeys.INPUT``.
        """
        # "===" is appended before decoding -- presumably so payloads sent without
        # their trailing padding still decode; verify against the serve clients.
        padded_payload = (data + "===").encode("ascii")
        image_buffer = BytesIO(base64.b64decode(padded_payload))
        return {DefaultDataKeys.INPUT: PILImage.open(image_buffer, mode="r")}

    @property
    def example_input(self) -> str:
        """Return the bundled ``fish.jpg`` asset as a base64-encoded string."""
        asset_path = Path(flash.ASSETS_ROOT) / "fish.jpg"
        with asset_path.open("rb") as handle:
            encoded = base64.b64encode(handle.read())
        return encoded.decode("UTF-8")


class ImagePathsDataSource(PathsDataSource):

Expand Down
17 changes: 6 additions & 11 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE
from flash.image.data import ImageDeserializer
from flash.image.segmentation.serialization import SegmentationLabels
from flash.image.segmentation.transforms import default_transforms, train_default_transforms

Expand Down Expand Up @@ -215,19 +216,13 @@ def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
return sample


class SemanticSegmentationDeserializer(Deserializer):

def __init__(self):

self.to_tensor = torchvision.transforms.ToTensor()
class SemanticSegmentationDeserializer(ImageDeserializer):

def deserialize(self, data: str) -> torch.Tensor:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
buffer = BytesIO(img)
img = PILImage.open(buffer, mode="r")
img = self.to_tensor(img)
return {DefaultDataKeys.INPUT: img, DefaultDataKeys.METADATA: {"size": img.shape}}
result = super().deserialize(data)
result[DefaultDataKeys.INPUT] = self.to_tensor(result[DefaultDataKeys.INPUT])
result[DefaultDataKeys.METADATA] = {"size": result[DefaultDataKeys.INPUT].shape}
return result


class SemanticSegmentationPreprocess(Preprocess):
Expand Down
Loading