Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into issue_430
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Jun 21, 2021
2 parents d26ef67 + 40fc879 commit d92090d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: end-of-file-fixer

- repo: https://github.com/PyCQA/isort
rev: 5.8.0
rev: 5.9.1
hooks:
- id: isort

Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ print(predictions)

### Serving

`Serve` is a framework agnostic serving engine ! [Learn more](https://lightning-flash.readthedocs.io/en/latest/general/serve.html#) and [find examples](flash_examples/serve/generic/boston_prediction/inference_server.py).
`Serve` is a framework agnostic serving engine ! [Learn more](https://lightning-flash.readthedocs.io/en/latest/general/serve.html#) and [check out our examples](flash_examples/serve).

```python
from flash.text import TranslationTask
from flash.text import TextClassifier

model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
model.serve()
```

Expand Down
8 changes: 5 additions & 3 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,11 @@ def __new__(mcs, *args, **kwargs):
result = ABCMeta.__new__(mcs, *args, **kwargs)
if result.required_extras is not None:
result.__init__ = _requires_extras(result.required_extras)(result.__init__)
result.load_from_checkpoint = classmethod(
_requires_extras(result.required_extras)(result.load_from_checkpoint.__func__)
)
load_from_checkpoint = getattr(result, "load_from_checkpoint", None)
if load_from_checkpoint is not None:
result.load_from_checkpoint = classmethod(
_requires_extras(result.required_extras)(result.load_from_checkpoint.__func__)
)
return result


Expand Down

0 comments on commit d92090d

Please sign in to comment.