Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Flash serve bug fixes (#422)
Browse files Browse the repository at this point in the history
* Flash serve bug fixes

* Fixes
  • Loading branch information
ethanwharris authored Jun 17, 2021
1 parent bbb6f9f commit 089c4e8
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 22 deletions.
4 changes: 2 additions & 2 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ We get the following output:
.. testcode::
:hide:

assert all([prediction in [0, 1] for prediction in predictions])
assert all([prediction in ["positive", "negative"] for prediction in predictions])

.. code-block::
[1, 1, 0]
["negative", "negative", "positive"]
-------

Expand Down
6 changes: 1 addition & 5 deletions flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,7 @@ def forward(self, batch: Sequence[Any]):
self.save_fn(pred)
else:
self.save_fn(final_preds)
else:
# todo (tchaton): Debug the serializer not iterating over a list.
if self.is_serving and isinstance(final_preds, list) and len(final_preds) == 1:
return final_preds[0]
return final_preds
return final_preds

def __str__(self) -> str:
return (
Expand Down
23 changes: 14 additions & 9 deletions flash/core/serve/flash_components.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import inspect
from pathlib import Path
from typing import Any, Callable, Optional, Type
from typing import Any, Callable, Mapping, Optional

import torch
from pytorch_lightning.trainer.states import RunningStage

from flash import Task
from flash.core.serve import Composition, expose, GridModel, ModelComponent
from flash.core.serve.core import FilePath, GridModelValidArgs_T, GridserveScriptLoader
from flash.core.data.data_source import DefaultDataKeys
from flash.core.serve.core import FilePath, GridserveScriptLoader
from flash.core.serve.types.base import BaseType


Expand All @@ -34,9 +31,17 @@ def __init__(
):
self._serializer = serializer

def serialize(self, output) -> Any: # pragma: no cover
result = self._serializer(output)
return result
def serialize(self, outputs) -> Any: # pragma: no cover
results = []
if isinstance(outputs, list) or isinstance(outputs, torch.Tensor):
for output in outputs:
result = self._serializer(output)
if isinstance(result, Mapping):
result = result[DefaultDataKeys.PREDS]
results.append(result)
if len(results) == 1:
return results[0]
return results

def deserialize(self, data: str) -> Any: # pragma: no cover
return None
Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def deserialize(self, data: str) -> torch.Tensor:
buffer = BytesIO(img)
img = PILImage.open(buffer, mode="r")
img = self.to_tensor(img)
return {DefaultDataKeys.INPUT: img, DefaultDataKeys.METADATA: img.shape}
return {DefaultDataKeys.INPUT: img, DefaultDataKeys.METADATA: {"size": img.shape}}


class SemanticSegmentationPreprocess(Preprocess):
Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import torch
from torch import nn
Expand Down
2 changes: 0 additions & 2 deletions flash_examples/predict/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

model.serializer = Labels()

# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.image import SemanticSegmentation
from flash.image.segmentation.serialization import SegmentationLabels

model = SemanticSegmentation.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model.serializer = SegmentationLabels(visualize=False)
model.serve()
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.classification import Labels
from flash.tabular import TabularClassifier

model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")
model.serializer = Labels(['Did not survive', 'Survived'])
model.serve()
4 changes: 2 additions & 2 deletions requirements/serve.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ fastapi>=0.65.2,<0.66.0
# to have full feature control of fastapi, manually install optional
# dependencies rather than installing fastapi[all]
# https://fastapi.tiangolo.com/#optional-dependencies
pydantic>=1.6.0,<2.0.0
starlette>=0.14.0
pydantic>1.8.1,<2.0.0
starlette==0.14.2
uvicorn[standard]>=0.12.0,<0.14.0
aiofiles
jinja2
Expand Down

0 comments on commit 089c4e8

Please sign in to comment.