This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Serve sanity checks #423

Merged · 22 commits · Jun 18, 2021
13 changes: 12 additions & 1 deletion .github/workflows/ci-testing.yml
@@ -106,6 +106,11 @@ jobs:
pip list
shell: bash

- name: Install serve test dependencies
if: matrix.topic == 'serve'
run: |
pip install '.[all]' --pre --upgrade
ethanwharris marked this conversation as resolved.
- name: Cache datasets
uses: actions/cache@v2
with:
@@ -115,7 +120,8 @@

- name: Tests
env:
FIFTYONE_DO_NOT_TRACK: true
FLASH_TEST_TOPIC: ${{ matrix.topic }}
FIFTYONE_DO_NOT_TRACK: true
Member: shall it rather be just 0/1?

Collaborator Author: Not sure it can be. This was done by the FiftyOne people for something on their end, and I wouldn't want to change it in case something breaks.

run: |
# tox --sitepackages
coverage run --source flash -m pytest flash tests -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
@@ -143,3 +149,8 @@ jobs:
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false

- name: Uninstall
run: |
pip uninstall lightning-flash -y
shell: bash
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))
- Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390))
- Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401))
- Added a sanity checking feature to flash.serve ([#423](https://github.com/PyTorchLightning/lightning-flash/pull/423))

### Changed

1 change: 1 addition & 0 deletions flash/__init__.py
@@ -27,6 +27,7 @@
from flash.core.trainer import Trainer # noqa: E402

_PACKAGE_ROOT = os.path.dirname(__file__)
ASSETS_ROOT = os.path.join(_PACKAGE_ROOT, "assets")
PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1"

Binary file added flash/assets/fish.jpg
37 changes: 20 additions & 17 deletions flash/core/classification.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import torch
import torch.nn.functional as F
@@ -24,7 +24,6 @@
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE

if _FIFTYONE_AVAILABLE:
import fiftyone as fo
from fiftyone.core.labels import Classification, Classifications
else:
Classification, Classifications = None, None
@@ -83,34 +82,43 @@ def multi_label(self) -> bool:
return self._mutli_label


class Logits(ClassificationSerializer):
class PredsClassificationSerializer(ClassificationSerializer):
"""A :class:`~flash.core.classification.ClassificationSerializer` which gets the
:attr:`~flash.core.data.data_source.DefaultDataKeys.PREDS` from the sample.
"""

def serialize(self, sample: Any) -> Any:
if isinstance(sample, Mapping) and DefaultDataKeys.PREDS in sample:
sample = sample[DefaultDataKeys.PREDS]
if not isinstance(sample, torch.Tensor):
sample = torch.tensor(sample)
return sample


class Logits(PredsClassificationSerializer):
"""A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
return sample.tolist()
return super().serialize(sample).tolist()


class Probabilities(ClassificationSerializer):
class Probabilities(PredsClassificationSerializer):
"""A :class:`.Serializer` which applies a softmax to the model outputs (assumed to be logits) and converts to a
list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
sample = super().serialize(sample)
if self.multi_label:
return torch.sigmoid(sample).tolist()
return torch.softmax(sample, -1).tolist()


class Classes(ClassificationSerializer):
class Classes(PredsClassificationSerializer):
"""A :class:`.Serializer` which applies an argmax to the model outputs (either logits or probabilities) and
converts to a list.
Args:
multi_label: If true, treats outputs as multi label logits.
threshold: The threshold to use for multi_label classification.
"""

@@ -120,8 +128,7 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5):
self.threshold = threshold

def serialize(self, sample: Any) -> Union[int, List[int]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
sample = super().serialize(sample)
if self.multi_label:
one_hot = (sample.sigmoid() > self.threshold).int().tolist()
result = []
@@ -139,9 +146,7 @@ class Labels(Classes):
Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.LabelsState`.
multi_label: If true, treats outputs as multi label logits.
threshold: The threshold to use for multi_label classification.
"""

@@ -153,8 +158,6 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
labels = None

if self._labels is not None:
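
As a quick aside (not part of the diff): after the refactor, the dict-unpacking boilerplate lives once in PredsClassificationSerializer.serialize, and every subclass funnels through it via super(). A minimal sketch of the resulting behaviour on a raw prediction dict, per the docstrings above:

from flash.core.classification import Classes, Logits, Probabilities
from flash.core.data.data_source import DefaultDataKeys

sample = {DefaultDataKeys.PREDS: [0.1, 2.0, -1.0]}  # raw logits
print(Logits().serialize(sample))         # the logits back, as a list of floats
print(Probabilities().serialize(sample))  # softmax over the logits, as a list
print(Classes().serialize(sample))        # 1, the argmax index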
4 changes: 4 additions & 0 deletions flash/core/data/data_pipeline.py
@@ -119,6 +119,10 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) ->
data_pipeline_state._initialized = True # TODO: Not sure we need this
return data_pipeline_state

@property
def example_input(self) -> str:
return self._deserializer.example_input

@staticmethod
def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool:
"""
7 changes: 6 additions & 1 deletion flash/core/data/process.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from abc import ABC, abstractclassmethod, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence

import torch
from pytorch_lightning.trainer.states import RunningStage
@@ -569,6 +569,11 @@ class Deserializer(Properties):
def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor???
raise NotImplementedError

@property
@abstractmethod
def example_input(self) -> str:
pass

def __call__(self, sample: Any) -> Any:
return self.deserialize(sample)

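
Since example_input is now abstract, each concrete Deserializer must supply a sample payload for the serve sanity check to POST. A hypothetical sketch (the class name and the base64 convention are illustrative, not from this PR; fish.jpg is the asset added above):

import base64

from flash.core.data.process import Deserializer


class Base64ImageDeserializer(Deserializer):
    """Hypothetical deserializer that accepts base64-encoded image bytes."""

    def deserialize(self, data: str) -> bytes:
        # Decode the request payload back into raw image bytes.
        return base64.b64decode(data.encode("utf-8"))

    @property
    def example_input(self) -> str:
        # The string that run_serve_sanity_check() will POST to /predict.
        with open("flash/assets/fish.jpg", "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")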
77 changes: 41 additions & 36 deletions flash/core/model.py
@@ -41,8 +41,9 @@
)
from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.serve import Composition, expose, ModelComponent
from flash.core.serve import Composition
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import _SERVE_AVAILABLE


class BenchmarkConvergenceCI(Callback):
@@ -390,12 +391,18 @@ def build_data_pipeline(
else:
data_source = preprocess.data_source_of_name(data_source)

deserializer = deserializer or getattr(preprocess, "deserializer", None)
if deserializer is None or type(deserializer) == Deserializer:
deserializer = getattr(preprocess, "deserializer", deserializer)

data_pipeline = DataPipeline(data_source, preprocess, postprocess, deserializer, serializer)
self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
return data_pipeline

@torch.jit.unused
@property
def is_servable(self) -> bool:
return type(self.build_data_pipeline()._deserializer) != Deserializer

@torch.jit.unused
@property
def data_pipeline(self) -> DataPipeline:
@@ -592,41 +599,39 @@ def configure_callbacks(self):
if flash._IS_TESTING and torch.cuda.is_available():
return [BenchmarkConvergenceCI()]

def serve(self, host: str = "127.0.0.1", port: int = 8000) -> 'Composition':
from flash.core.serve.flash_components import FlashInputs, FlashOutputs
def run_serve_sanity_check(self):
if not _SERVE_AVAILABLE:
raise ModuleNotFoundError("Please, pip install 'lightning-flash[serve]'")
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")

class FlashServeModelComponent(ModelComponent):
from fastapi.testclient import TestClient

def __init__(self, model):
self.model = model
self.model.eval()
self.data_pipeline = self.model.build_data_pipeline()
self.worker_preprocessor = self.data_pipeline.worker_preprocessor(
RunningStage.PREDICTING, is_serving=True
)
self.device_preprocessor = self.data_pipeline.device_preprocessor(RunningStage.PREDICTING)
self.postprocessor = self.data_pipeline.postprocessor(RunningStage.PREDICTING, is_serving=True)
# todo (tchaton) Remove this hack
self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3
self.device = self.model.device

@expose(
inputs={"inputs": FlashInputs(self.data_pipeline.deserialize_processor())},
outputs={"outputs": FlashOutputs(self.data_pipeline.serialize_processor())},
)
def predict(self, inputs):
with torch.no_grad():
inputs = self.worker_preprocessor(inputs)
if self.extra_arguments:
inputs = self.model.transfer_batch_to_device(inputs, self.device, 0)
else:
inputs = self.model.transfer_batch_to_device(inputs, self.device)
inputs = self.device_preprocessor(inputs)
preds = self.model.predict_step(inputs, 0)
preds = self.postprocessor(preds)
return preds

comp = FlashServeModelComponent(self)
composition = Composition(predict=comp)
from flash.core.serve.flash_components import build_flash_serve_model_component

print("Running serve sanity check")
comp = build_flash_serve_model_component(self)
composition = Composition(predict=comp, TESTING=True, DEBUG=True)
app = composition.serve(host="0.0.0.0", port=8000)

with TestClient(app) as tc:
input_str = self.data_pipeline._deserializer.example_input
body = {"session": "UUID", "payload": {"inputs": {"data": input_str}}}
resp = tc.post("http://0.0.0.0:8000/predict", json=body)
print(f"Sanity check response: {resp.json()}")

def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> 'Composition':
if not _SERVE_AVAILABLE:
raise ModuleNotFoundError("Please, pip install 'lightning-flash[serve]'")
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")

from flash.core.serve.flash_components import build_flash_serve_model_component

if sanity_check:
self.run_serve_sanity_check()

comp = build_flash_serve_model_component(self)
composition = Composition(predict=comp, TESTING=flash._IS_TESTING)
composition.serve(host=host, port=port)
return composition
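
From the user's side, the check runs before the real server binds to the requested host and port. A sketch, assuming an image classification task and checkpoint (the path is illustrative):

from flash.image import ImageClassifier

model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
# With sanity_check=True (the default), serve() first spins the composition
# up under FastAPI's TestClient, POSTs the deserializer's example_input to
# /predict, and prints the response, then starts the real server.
model.serve(host="127.0.0.1", port=8000, sanity_check=True)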
39 changes: 39 additions & 0 deletions flash/core/serve/flash_components.py
@@ -1,9 +1,12 @@
import inspect
from typing import Any, Callable, Mapping, Optional

import torch
from pytorch_lightning.trainer.states import RunningStage

from flash import Task
from flash.core.data.data_source import DefaultDataKeys
from flash.core.serve import expose, ModelComponent
from flash.core.serve.core import FilePath, GridserveScriptLoader
from flash.core.serve.types.base import BaseType

@@ -54,3 +57,39 @@ class FlashServeScriptLoader(GridserveScriptLoader):
def __init__(self, location: FilePath):
self.location = location
self.instance = self.model_cls.load_from_checkpoint(location)


def build_flash_serve_model_component(model):

data_pipeline = model.build_data_pipeline()

class FlashServeModelComponent(ModelComponent):

def __init__(self, model):
self.model = model
self.model.eval()
self.data_pipeline = model.build_data_pipeline()
self.worker_preprocessor = self.data_pipeline.worker_preprocessor(RunningStage.PREDICTING, is_serving=True)
self.device_preprocessor = self.data_pipeline.device_preprocessor(RunningStage.PREDICTING)
self.postprocessor = self.data_pipeline.postprocessor(RunningStage.PREDICTING, is_serving=True)
# todo (tchaton) Remove this hack
self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3
self.device = self.model.device

@expose(
inputs={"inputs": FlashInputs(data_pipeline.deserialize_processor())},
outputs={"outputs": FlashOutputs(data_pipeline.serialize_processor())},
)
def predict(self, inputs):
with torch.no_grad():
inputs = self.worker_preprocessor(inputs)
if self.extra_arguments:
inputs = self.model.transfer_batch_to_device(inputs, self.device, 0)
else:
inputs = self.model.transfer_batch_to_device(inputs, self.device)
inputs = self.device_preprocessor(inputs)
preds = self.model.predict_step(inputs, 0)
preds = self.postprocessor(preds)
return preds

return FlashServeModelComponent(model)
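
The body accepted by the exposed predict endpoint has the same shape as the one run_serve_sanity_check() builds. A minimal external client might look like this (the requests usage and base64 payload are assumptions, not part of the diff):

import base64

import requests

with open("flash/assets/fish.jpg", "rb") as f:
    data = base64.b64encode(f.read()).decode("utf-8")

body = {"session": "UUID", "payload": {"inputs": {"data": data}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())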
12 changes: 10 additions & 2 deletions flash/core/trainer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import warnings
from argparse import ArgumentParser, Namespace
from functools import wraps
@@ -29,6 +28,7 @@

import flash
from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks
from flash.core.utilities.imports import _SERVE_AVAILABLE


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
@@ -72,7 +72,7 @@ def insert_env_defaults(self, *args, **kwargs):
class Trainer(PlTrainer):

@_defaults_from_env_vars
def __init__(self, *args, **kwargs):
def __init__(self, *args, serve_sanity_check: bool = True, **kwargs):
if flash._IS_TESTING:
if torch.cuda.is_available():
kwargs["gpus"] = 1
@@ -85,6 +85,14 @@ def __init__(self, *args, **kwargs):
kwargs["fast_dev_run"] = True
super().__init__(*args, **kwargs)

self.serve_sanity_check = serve_sanity_check

def run_sanity_check(self, ref_model):
super().run_sanity_check(ref_model)

if self.serve_sanity_check and ref_model.is_servable and _SERVE_AVAILABLE:
ref_model.run_serve_sanity_check()

def fit(
self,
model: LightningModule,
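
At the Trainer level, the new behaviour is a plain constructor flag; a sketch (model and datamodule setup omitted):

from flash import Trainer

# Default: after Lightning's usual validation sanity check, also POST the
# deserializer's example_input to a temporary serve app (skipped when the
# model is not servable or the serve extras are not installed).
trainer = Trainer(max_epochs=1)

# Opt out of the serve check only:
trainer = Trainer(max_epochs=1, serve_sanity_check=False)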
1 change: 0 additions & 1 deletion flash/core/utilities/imports.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""

import importlib
import operator
from importlib.util import find_spec