Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add serve #399

Merged
merged 33 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ jobs:
python-version: 3.8
requires: 'latest'
topic: 'image_style_transfer'
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
topic: 'serve'

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,4 @@ kinetics
movie_posters
CameraRGB
CameraSeg
flash_examples/serve/tabular_classification/data
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ predictions = model.predict([
print(predictions)
```

### Serving

`Serve` is a framework agnostic serving engine ! [Learn more](https://lightning-flash.readthedocs.io/en/latest/reference/flash_to_production.html#) and [find examples](flash_examples/serve/generic/boston_prediction/inference_server.py).

```python
from flash.text import TranslationTask

model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")
model.serve()
```

### Finetuning

First, finetune:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/inference_server.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/swagger_ui.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
209 changes: 209 additions & 0 deletions docs/source/general/serve.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
#####
Serve
#####

.. _serve:

Serve is a library to easily serve models in production.

***********
Terminology
***********

Here are common terms you need to be familiar with:

.. list-table:: Terminology
:widths: 20 80
:header-rows: 1

* - Term
- Definition
* - de-serialization
- Transform data encoded as text into tensors
* - inference function
- A function taking the decoded tensors and forward them through the model to produce predictions.
* - serialization
- Transform the predictions tensors back to a text encoding.
* - :class:`~flash.core.serve.ModelComponent`
- The :class:`~flash.core.serve.ModelComponent` contains the de-serialization, inference and serialization functions.
* - :class:`~flash.core.serve.GridModel`
- The :class:`~flash.core.serve.GridModel` is an helper track the asset file related to a model
* - :class:`~flash.core.serve.Composition`
- The :class:`~flash.core.serve.Composition` defines the computations / endpoints to create & run
* - :func:`~flash.core.serve.decorators.expose`
- The :func:`~flash.core.serve.decorators.expose` function is a python decorator used to
augment the :class:`~flash.core.serve.ModelComponent` inference function with de-serialization, serialization.


*******
Example
*******

In this tutorial, we will serve a Convolutional Neural Network called Resnet18 from the `PyTorchVision library <https://github.com/pytorch/vision>`_ in 3 steps.

The entire tutorial can be found under ``grid-sdk/examples/serve/image_classification``.

Introduction
============


Traditionally, an inference pipeline is made out of 3 steps:

* ``de-serialization``: Transform data encoded as text into tensors.
* ``inference function``: A function taking the decoded tensors and forward them through the model to produce predictions.
* ``serialization``: Transform the predictions tensors back as text.

In this example, we will implement only the inference function as Grid Serve already provides some built-in ``de-serialization`` and ``serialization`` functions with :class:`~flash.core.serve.types.image.Image`


Step 1 - Create a ModelComponent
================================

Inside ``inference_serve.py``,
we will implement a ``ClassificationInference`` class, which overrides :class:`~flash.core.serve.ModelComponent`.

First, we need make the following imports:

.. code-block::

import torch
import torchvision

from flash.core.serve import Composition, GridModel, ModelComponent, expose
from flash.core.serve.types import Image, Label


.. image:: ../_static/images/data_serving_flow.png
:width: 100%
:alt: Data Serving Flow


To implement ``ClassificationInference``, we need to implement a method responsible for ``inference function`` and decorated with the :func:`~flash.core.serve.decorators.expose` function.

The name of the inference method isn't constrained, but we will use ``classify`` as appropriate in this example.

Our classify function will take a tensor image, apply some normalization on it, and forward it through the model.

.. code-block::

def classify(img):
img = img.float() / 255
mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
img = (img - mean) / std
img = img.permute(0, 3, 2, 1)
out = self.model(img)
return out.argmax()


The :func:`~flash.core.serve.decorators.expose` is a python decorator extending the decorated function with the ``de-serialization``, ``serialization`` steps.

.. note:: Grid Serve was designed this way to enable several models to be chained together by removing the decorator.

The :func:`~flash.core.serve.decorators.expose` function takes 2 arguments:

* ``inputs``: Dictionary mapping the decorated function inputs to :class:`~flash.core.serve.types.base.BaseType` objects.
* ``outputs``: Dictionary mapping the decorated function outputs to :class:`~flash.core.serve.types.base.BaseType` objects.

A :class:`~flash.core.serve.types.base.BaseType` is a python `dataclass <https://docs.python.org/3/library/dataclasses.html>`_
which implements a ``serialize`` and ``deserialize`` function.

.. note:: Grid Serve has already several :class:`~flash.core.serve.types.base.BaseType` built-in such as :class:`~flash.core.serve.types.image.Image` or :class:`~flash.core.serve.types.text.Text`.

.. code-bloc image_classification


class ClassificationInference(ModelComponent):
def __init__(self, model: GridModel):
self.model = model

@expose(
inputs={"img": Image()},
outputs={"prediction": Label(path="imagenet_labels.txt")},
)
def classify(self, img):
img = img.float() / 255
mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
img = (img - mean) / std
img = img.permute(0, 3, 2, 1)
out = self.model(img)
return out.argmax()


Step 2 - Create a scripted Model
================================

Using the `PyTorchVision library <https://github.com/pytorch/vision>`_, we create a ``resnet18`` and use torch.jit.script to script the model.


.. note:: TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.

.. code-block::

model = torchvision.models.resnet18(pretrained=True).eval()
torch.jit.script(model).save("resnet.pt")

Step 3 - Serve the model
========================

The :class:`~flash.core.serve.GridModel` takes as argument the path to the TorchScripted model and then will be passed to our ``ClassificationInference`` class.

The ``ClassificationInference`` instance will be passed as argument to a :class:`~flash.core.serve.Composition` class.

Once the :class:`~flash.core.serve.Composition` class is instantiated, just call its :func:`~flash.core.serve.Composition.serve` method.


.. code-block::

resnet = GridModel("resnet.pt")
comp = ClassificationInference(resnet)
composition = Composition(classification=comp)
composition.serve()


Launching the server.
=====================

In Terminal 1
^^^^^^^^^^^^^^

Just run:

.. code-block::

python inference_server.py

And you should see this in your terminal

.. image:: ../_static/images/inference_server.png
:width: 100%
:alt: Data Serving Flow


You should also see an Swagger UI already built for you at ``http://127.0.0.1:8000/docs``

.. image:: ../_static/images/swagger_ui.png
:width: 100%
:alt: Data Serving Flow


In Terminal 2
^^^^^^^^^^^^^^

Run this script from another terminal:

.. code-block::

import base64
from pathlib import Path

import requests

with Path("fish.jpg").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"img": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
# {'session': 'UUID', 'result': {'prediction': 'goldfish, Carassius auratus'}}
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Lightning Flash
installation
custom_task
reference/flash_to_pl
reference/flash_to_production

.. toctree::
:maxdepth: 1
Expand All @@ -40,6 +41,7 @@ Lightning Flash
general/data
general/callback
general/registry
general/serve


.. toctree::
Expand Down
20 changes: 20 additions & 0 deletions docs/source/reference/flash_to_production.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
########################
From Flash to Production
########################

Flash makes it simple to deploy models in production.

Server Side
^^^^^^^^^^^

.. literalinclude:: ../../../flash_examples/serve/segmentic_segmentation/inference_server.py
:language: python
:lines: 14-


Client Side
^^^^^^^^^^^

.. literalinclude:: ../../../flash_examples/serve/segmentic_segmentation/client.py
:language: python
:lines: 14-
4 changes: 4 additions & 0 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1"

if _IS_TESTING:
from pytorch_lightning import seed_everything
seed_everything(42)

__all__ = [
"DataSource",
"DataModule",
Expand Down
Loading