Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add graphviz to CI and torch 1.9.0 support (#413)
Browse files Browse the repository at this point in the history
* Try something

* Try fix

* Try fix

* Fixes

* Try fix

* Updates
  • Loading branch information
ethanwharris authored Jun 15, 2021
1 parent 72f5a9e commit 01fef90
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 34 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ jobs:
brew update
brew install libomp # https://github.com/pytorch/pytorch/issues/20030
- name: Install graphviz
if: matrix.topic == 'serve'
run: |
sudo apt-get install graphviz
- name: Set min. dependencies
if: matrix.requires == 'minimal'
run: |
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _compare_version(package: str, op, version) -> bool:


_TORCH_AVAILABLE = _module_available("torch")
_BOLTS_AVAILABLE = _module_available("pl_bolts") and _compare_version("torch", operator.lt, "1.9.0")
_PANDAS_AVAILABLE = _module_available("pandas")
_SKLEARN_AVAILABLE = _module_available("sklearn")
_TABNET_AVAILABLE = _module_available("pytorch_tabnet")
Expand Down
4 changes: 2 additions & 2 deletions flash/image/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn
from torch import nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE

if _TIMM_AVAILABLE:
import timm
Expand Down
2 changes: 1 addition & 1 deletion flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ImageEmbedder(Task):
def __init__(
self,
embedding_dim: Optional[int] = None,
backbone: str = "swav-imagenet",
backbone: str = "resnet101",
pretrained: bool = True,
loss_fn: Callable = F.cross_entropy,
optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
Expand Down
4 changes: 2 additions & 2 deletions flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

import torch.nn as nn
from deprecate import deprecated
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.image.backbones import catch_url_error

if _TORCHVISION_AVAILABLE:
Expand Down
4 changes: 2 additions & 2 deletions flash_examples/finetuning/image_classification_multi_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], L
metrics=F1(num_classes=len(genres)),
)

# 4. Create the trainer. Train on 2 gpus for 10 epochs.
trainer = flash.Trainer(max_epochs=10)
# 4. Create the trainer
trainer = flash.Trainer(fast_dev_run=True)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
Expand Down
5 changes: 2 additions & 3 deletions flash_examples/predict/image_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Create an ImageEmbedder with swav trained on imagenet.
# Check out SWAV: https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav
embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)
# 2. Create an ImageEmbedder with resnet101 trained on imagenet.
embedder = ImageEmbedder(backbone="resnet101", embedding_dim=128)

# 3. Generate an embedding from an image path.
embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"])
Expand Down
3 changes: 2 additions & 1 deletion tests/image/segmentation/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.
import pytest
import torch
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE

from flash.core.utilities.imports import _BOLTS_AVAILABLE
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES


Expand Down
4 changes: 2 additions & 2 deletions tests/image/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
import urllib.error

import pytest
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE

from flash.core.utilities.imports import _TIMM_AVAILABLE
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE
from flash.image.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES


Expand Down
39 changes: 18 additions & 21 deletions tests/serve/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,9 @@ def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj):
},
"session": "UUID",
}
# TODO: Add graphviz to CI
# resp = tc.get("http://127.0.0.1:8000/predict_seat/dag")
# assert resp.headers["content-type"] == "text/html; charset=utf-8"
# assert resp.template.name == "dag.html"
resp = tc.get("http://127.0.0.1:8000/predict_seat/dag")
assert resp.headers["content-type"] == "text/html; charset=utf-8"
assert resp.template.name == "dag.html"


@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
Expand Down Expand Up @@ -317,10 +316,9 @@ def test_composed_does_not_eliminate_endpoint_serialization(session_global_datad
},
"session": "UUID",
}
# TODO: Add graphviz to CI
# resp = tc.get("http://127.0.0.1:8000/predict_seat/dag")
# assert resp.headers["content-type"] == "text/html; charset=utf-8"
# assert resp.template.name == "dag.html"
resp = tc.get("http://127.0.0.1:8000/predict_seat/dag")
assert resp.headers["content-type"] == "text/html; charset=utf-8"
assert resp.template.name == "dag.html"


@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
Expand Down Expand Up @@ -385,19 +383,18 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ
app = composit.serve(host="0.0.0.0", port=8000)

with TestClient(app) as tc:
# TODO: Add graphviz to CI
# resp = tc.get("http://127.0.0.1:8000/gridserve/component_dags")
# assert resp.headers["content-type"] == "text/html; charset=utf-8"
# assert resp.template.name == "dag.html"
# resp = tc.get("http://127.0.0.1:8000/predict_seat/dag")
# assert resp.headers["content-type"] == "text/html; charset=utf-8"
# assert resp.template.name == "dag.html"
# resp = tc.get("http://127.0.0.1:8000/predict_seat_img/dag")
# assert resp.headers["content-type"] == "text/html; charset=utf-8"
# assert resp.template.name == "dag.html"
# resp = tc.get("http://127.0.0.1:8000/predict_seat_img_two/dag")
# assert resp.headers["content-type"] == "text/html; charset=utf-8"
# assert resp.template.name == "dag.html"
resp = tc.get("http://127.0.0.1:8000/gridserve/component_dags")
assert resp.headers["content-type"] == "text/html; charset=utf-8"
assert resp.template.name == "dag.html"
resp = tc.get("http://127.0.0.1:8000/predict_seat/dag")
assert resp.headers["content-type"] == "text/html; charset=utf-8"
assert resp.template.name == "dag.html"
resp = tc.get("http://127.0.0.1:8000/predict_seat_img/dag")
assert resp.headers["content-type"] == "text/html; charset=utf-8"
assert resp.template.name == "dag.html"
resp = tc.get("http://127.0.0.1:8000/predict_seat_img_two/dag")
assert resp.headers["content-type"] == "text/html; charset=utf-8"
assert resp.template.name == "dag.html"

with (session_global_datadir / "cat.jpg").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")
Expand Down

0 comments on commit 01fef90

Please sign in to comment.