This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add serve (#399)
* update

* update

* update

* update

* add serve

* resolve flake8

* remove annotations

* resolve typing

* add serve tests

* update

* resolve flake8

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* resolve imports

* convert to code-block

* update

* resolve doc

* bypass fixture

* update

* update on comments

* update readme

* remove todo
tchaton authored Jun 11, 2021
1 parent 5552a78 commit 4384e4a
Showing 114 changed files with 12,148 additions and 106 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yml
@@ -41,6 +41,10 @@ jobs:
python-version: 3.8
requires: 'latest'
topic: 'image_style_transfer'
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
topic: 'serve'

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
1 change: 1 addition & 0 deletions .gitignore
@@ -155,3 +155,4 @@ kinetics
movie_posters
CameraRGB
CameraSeg
flash_examples/serve/tabular_classification/data
11 changes: 11 additions & 0 deletions README.md
@@ -117,6 +117,17 @@ predictions = model.predict([
print(predictions)
```

### Serving

`Serve` is a framework-agnostic serving engine! [Learn more](https://lightning-flash.readthedocs.io/en/latest/reference/flash_to_production.html#) and [find examples](flash_examples/serve/generic/boston_prediction/inference_server.py).

```python
from flash.text import TranslationTask

model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")
model.serve()
```

### Finetuning

First, finetune:
Binary file added docs/source/_static/images/data_serving_flow.png
Binary file added docs/source/_static/images/inference_server.png
Binary file added docs/source/_static/images/swagger_ui.png
209 changes: 209 additions & 0 deletions docs/source/general/serve.rst
@@ -0,0 +1,209 @@
#####
Serve
#####

.. _serve:

Serve is a library to easily serve models in production.

***********
Terminology
***********

Here are common terms you need to be familiar with:

.. list-table:: Terminology
   :widths: 20 80
   :header-rows: 1

   * - Term
     - Definition
   * - de-serialization
     - Transform data encoded as text into tensors.
   * - inference function
     - A function that takes the decoded tensors and forwards them through the model to produce predictions.
   * - serialization
     - Transform the prediction tensors back into a text encoding.
   * - :class:`~flash.core.serve.ModelComponent`
     - The :class:`~flash.core.serve.ModelComponent` contains the de-serialization, inference and serialization functions.
   * - :class:`~flash.core.serve.GridModel`
     - The :class:`~flash.core.serve.GridModel` is a helper that tracks the asset file related to a model.
   * - :class:`~flash.core.serve.Composition`
     - The :class:`~flash.core.serve.Composition` defines the computations / endpoints to create & run.
   * - :func:`~flash.core.serve.decorators.expose`
     - The :func:`~flash.core.serve.decorators.expose` function is a Python decorator used to
       augment the :class:`~flash.core.serve.ModelComponent` inference function with de-serialization and serialization.


*******
Example
*******

In this tutorial, we will serve a convolutional neural network, ResNet18 from the `torchvision library <https://github.com/pytorch/vision>`_, in 3 steps.

The entire tutorial can be found under ``grid-sdk/examples/serve/image_classification``.

Introduction
============


Traditionally, an inference pipeline consists of 3 steps:

* ``de-serialization``: Transform data encoded as text into tensors.
* ``inference function``: A function that takes the decoded tensors and forwards them through the model to produce predictions.
* ``serialization``: Transform the prediction tensors back into text.

In this example, we implement only the inference function, because Grid Serve already provides built-in ``de-serialization`` and ``serialization`` functions through :class:`~flash.core.serve.types.image.Image`.
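
The three steps can be sketched in plain Python without any serving library. This is a minimal illustration only: the function names, the use of base64 text as input, and the toy "model" (an argmax over byte values standing in for a tensor) are assumptions made for this sketch, not part of Grid Serve.

```python
import base64
import json


def deserialize(payload: str) -> list:
    """De-serialization: decode base64 text into a list of byte values (a stand-in for a tensor)."""
    return list(base64.b64decode(payload))


def inference(values: list) -> int:
    """Inference function: a toy 'model' that returns the index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


def serialize(prediction: int) -> str:
    """Serialization: encode the prediction back as JSON text."""
    return json.dumps({"prediction": prediction})


def run_pipeline(payload: str) -> str:
    """Chain the three steps in the order described above."""
    return serialize(inference(deserialize(payload)))


encoded = base64.b64encode(bytes([3, 250, 7])).decode("utf-8")
result = run_pipeline(encoded)
print(result)
```

Keeping the three functions separate is what lets a serving layer supply the first and last step, so that only the middle one has to be written by hand.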


Step 1 - Create a ModelComponent
================================

Inside ``inference_server.py``,
we will implement a ``ClassificationInference`` class, which subclasses :class:`~flash.core.serve.ModelComponent`.

First, we need to make the following imports:

.. code-block::

    import torch
    import torchvision

    from flash.core.serve import Composition, GridModel, ModelComponent, expose
    from flash.core.serve.types import Image, Label

.. image:: ../_static/images/data_serving_flow.png
    :width: 100%
    :alt: Data Serving Flow


To implement ``ClassificationInference``, we need to write a method that performs the ``inference function`` step and decorate it with the :func:`~flash.core.serve.decorators.expose` function.

The name of the inference method isn't constrained, but we will use ``classify`` as appropriate in this example.

Our classify function will take a tensor image, apply some normalization on it, and forward it through the model.

.. code-block::

    def classify(self, img):
        # Scale pixel values to [0, 1], then normalize with the ImageNet statistics.
        img = img.float() / 255
        mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
        std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
        img = (img - mean) / std
        # Move the channel dimension to position 1 before the forward pass.
        img = img.permute(0, 3, 2, 1)
        out = self.model(img)
        return out.argmax()

The :func:`~flash.core.serve.decorators.expose` function is a Python decorator that extends the decorated function with the ``de-serialization`` and ``serialization`` steps.

.. note:: Grid Serve was designed this way to enable several models to be chained together by removing the decorator.

The :func:`~flash.core.serve.decorators.expose` function takes 2 arguments:

* ``inputs``: Dictionary mapping the decorated function inputs to :class:`~flash.core.serve.types.base.BaseType` objects.
* ``outputs``: Dictionary mapping the decorated function outputs to :class:`~flash.core.serve.types.base.BaseType` objects.
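
To make the role of the two dictionaries concrete, here is a rough sketch of how a decorator of this kind could work. ``expose_sketch``, ``IntText`` and the ``raw`` attribute are hypothetical names invented for this illustration; this is not the Flash implementation, only an assumption about the general pattern.

```python
import functools


class IntText:
    """Hypothetical type: converts between decimal text and int."""

    def deserialize(self, text: str) -> int:
        return int(text)

    def serialize(self, value: int) -> str:
        return str(value)


def expose_sketch(inputs: dict, outputs: dict):
    """Hypothetical stand-in for an ``expose``-style decorator (not the Flash code)."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(payload: dict) -> dict:
            # De-serialize every declared input by name.
            kwargs = {name: typ.deserialize(payload[name]) for name, typ in inputs.items()}
            result = fn(**kwargs)
            # Serialize the return value under each declared output name.
            return {name: typ.serialize(result) for name, typ in outputs.items()}

        # Keep the undecorated function reachable so it can be called on raw values.
        wrapper.raw = fn
        return wrapper

    return decorator


@expose_sketch(inputs={"x": IntText()}, outputs={"doubled": IntText()})
def double(x: int) -> int:
    return 2 * x


response = double({"x": "21"})
print(response)
```

In this sketch the undecorated function stays available as ``double.raw``, which mirrors the idea that the inference logic itself never deals with text encodings.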

A :class:`~flash.core.serve.types.base.BaseType` is a Python `dataclass <https://docs.python.org/3/library/dataclasses.html>`_
that implements ``serialize`` and ``deserialize`` methods.

.. note:: Grid Serve already has several built-in :class:`~flash.core.serve.types.base.BaseType` classes, such as :class:`~flash.core.serve.types.image.Image` or :class:`~flash.core.serve.types.text.Text`.
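
As an illustration of the ``serialize`` / ``deserialize`` pair, the sketch below shows a label-like type that maps a class index to a readable name and back. ``LabelSketch`` is a hypothetical standalone dataclass written for this example; it does not subclass the real ``BaseType``, and the assumption that a label type behaves this way is inferred from the sample response at the end of this tutorial.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class LabelSketch:
    """Hypothetical label type mapping class indices to names and back."""

    classes: List[str] = field(default_factory=list)

    def serialize(self, index: int) -> str:
        # Prediction index -> human-readable label text.
        return self.classes[index]

    def deserialize(self, label: str) -> int:
        # Label text -> class index.
        return self.classes.index(label)


label = LabelSketch(classes=["tench, Tinca tinca", "goldfish, Carassius auratus"])
print(label.serialize(1))
```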

.. code-block::

    class ClassificationInference(ModelComponent):

        def __init__(self, model: GridModel):
            self.model = model

        @expose(
            inputs={"img": Image()},
            outputs={"prediction": Label(path="imagenet_labels.txt")},
        )
        def classify(self, img):
            img = img.float() / 255
            mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
            std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
            img = (img - mean) / std
            img = img.permute(0, 3, 2, 1)
            out = self.model(img)
            return out.argmax()

Step 2 - Create a scripted Model
================================

Using the `torchvision library <https://github.com/pytorch/vision>`_, we create a ``resnet18`` and use ``torch.jit.script`` to script the model.


.. note:: TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.

.. code-block::

    model = torchvision.models.resnet18(pretrained=True).eval()
    torch.jit.script(model).save("resnet.pt")

Step 3 - Serve the model
========================

The :class:`~flash.core.serve.GridModel` takes the path to the TorchScript model as its argument; the resulting object is then passed to our ``ClassificationInference`` class.

The ``ClassificationInference`` instance is in turn passed as a keyword argument to a :class:`~flash.core.serve.Composition`.

Once the :class:`~flash.core.serve.Composition` is instantiated, call its :func:`~flash.core.serve.Composition.serve` method.


.. code-block::

    resnet = GridModel("resnet.pt")
    comp = ClassificationInference(resnet)
    composition = Composition(classification=comp)
    composition.serve()

Launching the server
====================

In Terminal 1
^^^^^^^^^^^^^^

Just run:

.. code-block::

    python inference_server.py

You should see this in your terminal:

.. image:: ../_static/images/inference_server.png
    :width: 100%
    :alt: Inference server output


You should also see a Swagger UI already built for you at ``http://127.0.0.1:8000/docs``:

.. image:: ../_static/images/swagger_ui.png
    :width: 100%
    :alt: Swagger UI


In Terminal 2
^^^^^^^^^^^^^^

Run this script from another terminal:

.. code-block::

    import base64
    from pathlib import Path

    import requests

    with Path("fish.jpg").open("rb") as f:
        imgstr = base64.b64encode(f.read()).decode("UTF-8")

    body = {"session": "UUID", "payload": {"img": {"data": imgstr}}}
    resp = requests.post("http://127.0.0.1:8000/predict", json=body)
    print(resp.json())
    # {'session': 'UUID', 'result': {'prediction': 'goldfish, Carassius auratus'}}
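
The request and response layout used by the client script can be exercised without a running server. The sketch below is an illustration only: ``build_body`` and ``read_prediction`` are hypothetical helper names, generating the session value with ``uuid4`` is an assumption (the script above simply sends the literal string ``"UUID"``), and the sample response is copied from the output shown above rather than produced by a real request.

```python
import base64
import json
import uuid


def build_body(image_bytes: bytes, session: str) -> dict:
    """Wrap base64-encoded image bytes in the request layout used by the client script."""
    imgstr = base64.b64encode(image_bytes).decode("utf-8")
    return {"session": session, "payload": {"img": {"data": imgstr}}}


def read_prediction(response: dict) -> str:
    """Pull the prediction out of a response shaped like the sample output."""
    return response["result"]["prediction"]


session_id = str(uuid.uuid4())  # assumption: any string identifying the session
body = build_body(b"\xff\xd8\xff", session_id)  # placeholder bytes, not a real image

# The body must survive a JSON round trip, since the client sends it with ``json=body``.
round_tripped = json.loads(json.dumps(body))

sample = {"session": session_id, "result": {"prediction": "goldfish, Carassius auratus"}}
print(read_prediction(sample))
```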
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -14,6 +14,7 @@ Lightning Flash
installation
custom_task
reference/flash_to_pl
reference/flash_to_production

.. toctree::
:maxdepth: 1
@@ -40,6 +41,7 @@ Lightning Flash
general/data
general/callback
general/registry
general/serve


.. toctree::
20 changes: 20 additions & 0 deletions docs/source/reference/flash_to_production.rst
@@ -0,0 +1,20 @@
########################
From Flash to Production
########################

Flash makes it simple to deploy models in production.

Server Side
^^^^^^^^^^^

.. literalinclude:: ../../../flash_examples/serve/segmentic_segmentation/inference_server.py
:language: python
:lines: 14-


Client Side
^^^^^^^^^^^

.. literalinclude:: ../../../flash_examples/serve/segmentic_segmentation/client.py
:language: python
:lines: 14-
4 changes: 4 additions & 0 deletions flash/__init__.py
@@ -30,6 +30,10 @@
PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1"

if _IS_TESTING:
    from pytorch_lightning import seed_everything
    seed_everything(42)

__all__ = [
"DataSource",
"DataModule",