Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Serve sanity checks (#423)
Browse files Browse the repository at this point in the history
* Add serve sanity checking

* Add tests

* Fixes

* Update CHANGELOG.md

* Try fix

* Try fix

* Try fix

* Try fix

* Small fix

* Try fix

* Update flash/core/classification.py

Co-authored-by: Justus Schock <[email protected]>

* Updates

* Updates

* Fixes

* Try fix

* Update flash/core/model.py

Co-authored-by: thomas chaton <[email protected]>

* Updates

* Add trainer flag

* Fixes

* Fixes

* Fixes

Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
  • Loading branch information
3 people authored Jun 18, 2021
1 parent 089c4e8 commit b8a0820
Show file tree
Hide file tree
Showing 55 changed files with 466 additions and 283 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ jobs:
pip list
shell: bash

- name: Install serve test dependencies
if: matrix.topic == 'serve'
run: |
pip install '.[all]' --pre --upgrade
- name: Cache datasets
uses: actions/cache@v2
with:
Expand All @@ -115,7 +120,8 @@ jobs:

- name: Tests
env:
FIFTYONE_DO_NOT_TRACK: true
FLASH_TEST_TOPIC: ${{ matrix.topic }}
FIFTYONE_DO_NOT_TRACK: true
run: |
# tox --sitepackages
coverage run --source flash -m pytest flash tests -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
Expand Down Expand Up @@ -143,3 +149,8 @@ jobs:
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false

- name: Uninstall
run: |
pip uninstall lightning-flash -y
shell: bash
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))
- Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390))
- Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401))
- Added a sanity checking feature to flash.serve ([#423](https://github.com/PyTorchLightning/lightning-flash/pull/423))

### Changed

Expand Down
1 change: 1 addition & 0 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from flash.core.trainer import Trainer # noqa: E402

_PACKAGE_ROOT = os.path.dirname(__file__)
ASSETS_ROOT = os.path.join(_PACKAGE_ROOT, "assets")
PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1"

Expand Down
Binary file added flash/assets/fish.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 20 additions & 17 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import torch
import torch.nn.functional as F
Expand All @@ -24,7 +24,6 @@
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE

if _FIFTYONE_AVAILABLE:
import fiftyone as fo
from fiftyone.core.labels import Classification, Classifications
else:
Classification, Classifications = None, None
Expand Down Expand Up @@ -83,34 +82,43 @@ def multi_label(self) -> bool:
return self._mutli_label


class Logits(ClassificationSerializer):
class PredsClassificationSerializer(ClassificationSerializer):
"""A :class:`~flash.core.classification.ClassificationSerializer` which gets the
:attr:`~flash.core.data.data_source.DefaultDataKeys.PREDS` from the sample.
"""

def serialize(self, sample: Any) -> Any:
if isinstance(sample, Mapping) and DefaultDataKeys.PREDS in sample:
sample = sample[DefaultDataKeys.PREDS]
if not isinstance(sample, torch.Tensor):
sample = torch.tensor(sample)
return sample


class Logits(PredsClassificationSerializer):
"""A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
return sample.tolist()
return super().serialize(sample).tolist()


class Probabilities(ClassificationSerializer):
class Probabilities(PredsClassificationSerializer):
"""A :class:`.Serializer` which applies a softmax to the model outputs (assumed to be logits) and converts to a
list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
sample = super().serialize(sample)
if self.multi_label:
return torch.sigmoid(sample).tolist()
return torch.softmax(sample, -1).tolist()


class Classes(ClassificationSerializer):
class Classes(PredsClassificationSerializer):
"""A :class:`.Serializer` which applies an argmax to the model outputs (either logits or probabilities) and
converts to a list.
Args:
multi_label: If true, treats outputs as multi label logits.
threshold: The threshold to use for multi_label classification.
"""

Expand All @@ -120,8 +128,7 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5):
self.threshold = threshold

def serialize(self, sample: Any) -> Union[int, List[int]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
sample = super().serialize(sample)
if self.multi_label:
one_hot = (sample.sigmoid() > self.threshold).int().tolist()
result = []
Expand All @@ -139,9 +146,7 @@ class Labels(Classes):
Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.LabelsState`.
multi_label: If true, treats outputs as multi label logits.
threshold: The threshold to use for multi_label classification.
"""

Expand All @@ -153,8 +158,6 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
labels = None

if self._labels is not None:
Expand Down
4 changes: 4 additions & 0 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) ->
data_pipeline_state._initialized = True # TODO: Not sure we need this
return data_pipeline_state

@property
def example_input(self) -> str:
return self._deserializer.example_input

@staticmethod
def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool:
"""
Expand Down
7 changes: 6 additions & 1 deletion flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import os
from abc import ABC, abstractclassmethod, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence

import torch
from pytorch_lightning.trainer.states import RunningStage
Expand Down Expand Up @@ -569,6 +569,11 @@ class Deserializer(Properties):
def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor???
raise NotImplementedError

@property
@abstractmethod
def example_input(self) -> str:
pass

def __call__(self, sample: Any) -> Any:
return self.deserialize(sample)

Expand Down
77 changes: 41 additions & 36 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@
)
from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.serve import Composition, expose, ModelComponent
from flash.core.serve import Composition
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import _SERVE_AVAILABLE


class BenchmarkConvergenceCI(Callback):
Expand Down Expand Up @@ -390,12 +391,18 @@ def build_data_pipeline(
else:
data_source = preprocess.data_source_of_name(data_source)

deserializer = deserializer or getattr(preprocess, "deserializer", None)
if deserializer is None or type(deserializer) == Deserializer:
deserializer = getattr(preprocess, "deserializer", deserializer)

data_pipeline = DataPipeline(data_source, preprocess, postprocess, deserializer, serializer)
self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
return data_pipeline

@torch.jit.unused
@property
def is_servable(self) -> bool:
return type(self.build_data_pipeline()._deserializer) != Deserializer

@torch.jit.unused
@property
def data_pipeline(self) -> DataPipeline:
Expand Down Expand Up @@ -592,41 +599,39 @@ def configure_callbacks(self):
if flash._IS_TESTING and torch.cuda.is_available():
return [BenchmarkConvergenceCI()]

def serve(self, host: str = "127.0.0.1", port: int = 8000) -> 'Composition':
from flash.core.serve.flash_components import FlashInputs, FlashOutputs
def run_serve_sanity_check(self):
if not _SERVE_AVAILABLE:
raise ModuleNotFoundError("Please, pip install 'lightning-flash[serve]'")
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")

class FlashServeModelComponent(ModelComponent):
from fastapi.testclient import TestClient

def __init__(self, model):
self.model = model
self.model.eval()
self.data_pipeline = self.model.build_data_pipeline()
self.worker_preprocessor = self.data_pipeline.worker_preprocessor(
RunningStage.PREDICTING, is_serving=True
)
self.device_preprocessor = self.data_pipeline.device_preprocessor(RunningStage.PREDICTING)
self.postprocessor = self.data_pipeline.postprocessor(RunningStage.PREDICTING, is_serving=True)
# todo (tchaton) Remove this hack
self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3
self.device = self.model.device

@expose(
inputs={"inputs": FlashInputs(self.data_pipeline.deserialize_processor())},
outputs={"outputs": FlashOutputs(self.data_pipeline.serialize_processor())},
)
def predict(self, inputs):
with torch.no_grad():
inputs = self.worker_preprocessor(inputs)
if self.extra_arguments:
inputs = self.model.transfer_batch_to_device(inputs, self.device, 0)
else:
inputs = self.model.transfer_batch_to_device(inputs, self.device)
inputs = self.device_preprocessor(inputs)
preds = self.model.predict_step(inputs, 0)
preds = self.postprocessor(preds)
return preds

comp = FlashServeModelComponent(self)
composition = Composition(predict=comp)
from flash.core.serve.flash_components import build_flash_serve_model_component

print("Running serve sanity check")
comp = build_flash_serve_model_component(self)
composition = Composition(predict=comp, TESTING=True, DEBUG=True)
app = composition.serve(host="0.0.0.0", port=8000)

with TestClient(app) as tc:
input_str = self.data_pipeline._deserializer.example_input
body = {"session": "UUID", "payload": {"inputs": {"data": input_str}}}
resp = tc.post("http://0.0.0.0:8000/predict", json=body)
print(f"Sanity check response: {resp.json()}")

def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> 'Composition':
if not _SERVE_AVAILABLE:
raise ModuleNotFoundError("Please, pip install 'lightning-flash[serve]'")
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")

from flash.core.serve.flash_components import build_flash_serve_model_component

if sanity_check:
self.run_serve_sanity_check()

comp = build_flash_serve_model_component(self)
composition = Composition(predict=comp, TESTING=flash._IS_TESTING)
composition.serve(host=host, port=port)
return composition
39 changes: 39 additions & 0 deletions flash/core/serve/flash_components.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import inspect
from typing import Any, Callable, Mapping, Optional

import torch
from pytorch_lightning.trainer.states import RunningStage

from flash import Task
from flash.core.data.data_source import DefaultDataKeys
from flash.core.serve import expose, ModelComponent
from flash.core.serve.core import FilePath, GridserveScriptLoader
from flash.core.serve.types.base import BaseType

Expand Down Expand Up @@ -54,3 +57,39 @@ class FlashServeScriptLoader(GridserveScriptLoader):
def __init__(self, location: FilePath):
self.location = location
self.instance = self.model_cls.load_from_checkpoint(location)


def build_flash_serve_model_component(model):

data_pipeline = model.build_data_pipeline()

class FlashServeModelComponent(ModelComponent):

def __init__(self, model):
self.model = model
self.model.eval()
self.data_pipeline = model.build_data_pipeline()
self.worker_preprocessor = self.data_pipeline.worker_preprocessor(RunningStage.PREDICTING, is_serving=True)
self.device_preprocessor = self.data_pipeline.device_preprocessor(RunningStage.PREDICTING)
self.postprocessor = self.data_pipeline.postprocessor(RunningStage.PREDICTING, is_serving=True)
# todo (tchaton) Remove this hack
self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3
self.device = self.model.device

@expose(
inputs={"inputs": FlashInputs(data_pipeline.deserialize_processor())},
outputs={"outputs": FlashOutputs(data_pipeline.serialize_processor())},
)
def predict(self, inputs):
with torch.no_grad():
inputs = self.worker_preprocessor(inputs)
if self.extra_arguments:
inputs = self.model.transfer_batch_to_device(inputs, self.device, 0)
else:
inputs = self.model.transfer_batch_to_device(inputs, self.device)
inputs = self.device_preprocessor(inputs)
preds = self.model.predict_step(inputs, 0)
preds = self.postprocessor(preds)
return preds

return FlashServeModelComponent(model)
12 changes: 10 additions & 2 deletions flash/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import warnings
from argparse import ArgumentParser, Namespace
from functools import wraps
Expand All @@ -29,6 +28,7 @@

import flash
from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks
from flash.core.utilities.imports import _SERVE_AVAILABLE


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
Expand Down Expand Up @@ -72,7 +72,7 @@ def insert_env_defaults(self, *args, **kwargs):
class Trainer(PlTrainer):

@_defaults_from_env_vars
def __init__(self, *args, **kwargs):
def __init__(self, *args, serve_sanity_check: bool = True, **kwargs):
if flash._IS_TESTING:
if torch.cuda.is_available():
kwargs["gpus"] = 1
Expand All @@ -85,6 +85,14 @@ def __init__(self, *args, **kwargs):
kwargs["fast_dev_run"] = True
super().__init__(*args, **kwargs)

self.serve_sanity_check = serve_sanity_check

def run_sanity_check(self, ref_model):
super().run_sanity_check(ref_model)

if self.serve_sanity_check and ref_model.is_servable and _SERVE_AVAILABLE:
ref_model.run_serve_sanity_check()

def fit(
self,
model: LightningModule,
Expand Down
1 change: 0 additions & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""

import importlib
import operator
from importlib.util import find_spec
Expand Down
Loading

0 comments on commit b8a0820

Please sign in to comment.