Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Jit support #389

Merged
merged 13 commits into from
Jun 10, 2021
4 changes: 2 additions & 2 deletions flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
pass


class Preprocess(BasePreprocess, Properties, Module):
class Preprocess(BasePreprocess, Properties):
"""The :class:`~flash.core.data.process.Preprocess` encapsulates all the data processing logic that should run before
the data is passed to the model. It is particularly useful when you want to provide an end to end implementation
which works with 4 different stages: ``train``, ``validation``, ``test``, and inference (``predict``).
Expand Down Expand Up @@ -454,7 +454,7 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):
return cls(**state_dict)


class Postprocess(Properties, Module):
class Postprocess(Properties):

def __init__(self, save_path: Optional[str] = None):
super().__init__()
Expand Down
8 changes: 7 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,14 @@ def _resolve(

return preprocess, postprocess, serializer

@torch.jit.unused
@property
def serializer(self) -> Optional[Serializer]:
"""The current :class:`.Serializer` associated with this model. If this property was set to a mapping
(e.g. ``.serializer = {'output1': SerializerOne()}``) then this will be a :class:`.MappingSerializer`."""
return self._serializer

@torch.jit.unused
@serializer.setter
def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]):
if isinstance(serializer, Mapping):
Expand Down Expand Up @@ -350,12 +352,14 @@ def build_data_pipeline(
self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
return data_pipeline

@torch.jit.unused
@property
def data_pipeline(self) -> DataPipeline:
"""The current :class:`.DataPipeline`. If set, the new value will override the :class:`.Task` defaults. See
:py:meth:`~build_data_pipeline` for more details on the resolution order."""
return self.build_data_pipeline()

@torch.jit.unused
@data_pipeline.setter
def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
self._preprocess, self._postprocess, self.serializer = Task._resolve(
Expand All @@ -366,14 +370,16 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
getattr(data_pipeline, '_postprocess_pipeline', None),
getattr(data_pipeline, '_serializer', None),
)
self._preprocess.state_dict()
# self._preprocess.state_dict()
if getattr(self._preprocess, "_ddp_params_and_buffers_to_ignore", None):
self._ddp_params_and_buffers_to_ignore = self._preprocess._ddp_params_and_buffers_to_ignore

@torch.jit.unused
@property
def preprocess(self) -> Preprocess:
return getattr(self.data_pipeline, '_preprocess_pipeline', None)

@torch.jit.unused
@property
def postprocess(self) -> Postprocess:
return getattr(self.data_pipeline, '_postprocess_pipeline', None)
Expand Down
9 changes: 6 additions & 3 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def get_model(
model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator)
return model

def forward(self, x: List[torch.Tensor]) -> Any:
return self.model(x)

def training_step(self, batch, batch_idx) -> Any:
"""The training step. Overrides ``Task.training_step``
"""
Expand All @@ -178,7 +181,7 @@ def training_step(self, batch, batch_idx) -> Any:
def validation_step(self, batch, batch_idx):
images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
# fasterrcnn takes only images for eval() mode
outs = self.model(images)
outs = self(images)
iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
self.log("val_iou", iou)

Expand All @@ -188,13 +191,13 @@ def on_validation_end(self) -> None:
def test_step(self, batch, batch_idx):
images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
# fasterrcnn takes only images for eval() mode
outs = self.model(images)
outs = self(images)
iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
self.log("test_iou", iou)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
images = batch[DefaultDataKeys.INPUT]
return self.model(images)
return self(images)

def configure_finetune_callback(self):
return [ObjectDetectionFineTuning(train_bn=True)]
Expand Down
15 changes: 7 additions & 8 deletions flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
from pytorch_lightning.utilities.distributed import rank_zero_warn
Expand Down Expand Up @@ -89,13 +89,12 @@ def __init__(
rank_zero_warn('embedding_dim. Remember to finetune first!')

def apply_pool(self, x):
if self.pooling_fn == torch.max:
# torch.max also returns argmax
x = self.pooling_fn(x, dim=-1)[0]
x = self.pooling_fn(x, dim=-1)[0]
else:
x = self.pooling_fn(x, dim=-1)
x = self.pooling_fn(x, dim=-1)
x = self.pooling_fn(x, dim=-1)
if torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]):
x = x[0]
x = self.pooling_fn(x, dim=-1)
if torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]):
x = x[0]
return x

def forward(self, x) -> torch.Tensor:
Expand Down
5 changes: 2 additions & 3 deletions flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A

def forward(self, x) -> torch.Tensor:
# infer the image to the model
res: Union[torch.Tensor, Dict[str, torch.Tensor]] = self.backbone(x)
res = self.backbone(x)

# some frameworks like torchvision return a dict.
# In particular, torchvision segmentation models return the output logits
# in the key `out`.
out: torch.Tensor
if isinstance(res, dict):
if torch.jit.isinstance(res, Dict[str, torch.Tensor]):
out = res['out']
elif torch.is_tensor(res):
out = res
Expand Down
6 changes: 5 additions & 1 deletion flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ def __init__(

def forward(self, x_in) -> torch.Tensor:
# TabNet takes single input, x_in is composed of (categorical, numerical)
x = torch.cat([x for x in x_in if x.numel()], dim=1)
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
xs = []
for x in x_in:
if x.numel():
xs.append(x)
x = torch.cat(xs, dim=1)
return self.model(x)[0]

def training_step(self, batch: Any, batch_idx: int) -> Any:
Expand Down
20 changes: 20 additions & 0 deletions tests/image/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch

Expand Down Expand Up @@ -108,3 +110,21 @@ def test_multilabel(tmpdir):
assert (torch.tensor(predictions) < 0).sum() == 0
assert len(predictions[0]) == num_classes == len(label)
assert len(torch.unique(label)) <= 2


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))])
def test_jit(tmpdir, jitter, args):
path = os.path.join(tmpdir, "test.pt")

model = ImageClassifier(2)
model.eval()

model = jitter(model, *args)

torch.jit.save(model, path)
model = torch.jit.load(path)

out = model(torch.rand(1, 3, 32, 32))
assert isinstance(out, torch.Tensor)
assert out.shape == torch.Size([1, 2])
22 changes: 22 additions & 0 deletions tests/image/detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -75,3 +77,23 @@ def test_training(tmpdir, model):
dl = DataLoader(ds, collate_fn=collate_fn)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model, dl)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
def test_jit(tmpdir):
path = os.path.join(tmpdir, "test.pt")

model = ObjectDetector(2)
model.eval()

model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN

torch.jit.save(model, path)
model = torch.jit.load(path)

out = model([torch.rand(3, 32, 32)])

# torchvision RCNN always returns a (Losses, Detections) tuple in scripting
out = out[1]

assert {"boxes", "labels", "scores"} <= out[0].keys()
Empty file.
38 changes: 38 additions & 0 deletions tests/image/embedding/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch

from flash.core.utilities.imports import _IMAGE_AVAILABLE
from flash.image import ImageEmbedder


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))])
def test_jit(tmpdir, jitter, args):
path = os.path.join(tmpdir, "test.pt")

model = ImageEmbedder(embedding_dim=128)
model.eval()

model = jitter(model, *args)

torch.jit.save(model, path)
model = torch.jit.load(path)

out = model(torch.rand(1, 3, 32, 32))
assert isinstance(out, torch.Tensor)
assert out.shape == torch.Size([1, 128])
19 changes: 19 additions & 0 deletions tests/image/segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Tuple

import numpy as np
Expand Down Expand Up @@ -114,3 +115,21 @@ def test_predict_numpy():
out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
assert isinstance(out[0], torch.Tensor)
assert out[0].shape == (10, 20)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))])
def test_jit(tmpdir, jitter, args):
path = os.path.join(tmpdir, "test.pt")

model = SemanticSegmentation(2)
model.eval()

model = jitter(model, *args)

torch.jit.save(model, path)
model = torch.jit.load(path)

out = model(torch.rand(1, 3, 32, 32))
assert isinstance(out, torch.Tensor)
assert out.shape == torch.Size([1, 2, 32, 32])
Empty file.
20 changes: 20 additions & 0 deletions tests/image/style_transfer/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os

import pytest
import torch

from flash.core.utilities.imports import _IMAGE_STLYE_TRANSFER, _PYSTICHE_GREATER_EQUAL_0_7_2
from flash.image.style_transfer import StyleTransfer
Expand All @@ -20,3 +23,20 @@ def test_style_transfer_task():
def test_style_transfer_task_import():
with pytest.raises(ModuleNotFoundError, match="[image_style_transfer]"):
StyleTransfer()


@pytest.mark.skipif(not _PYSTICHE_GREATER_EQUAL_0_7_2, reason="image style transfer libraries aren't installed.")
def test_jit(tmpdir):
path = os.path.join(tmpdir, "test.pt")

model = StyleTransfer()
model.eval()

model = torch.jit.trace(model, torch.rand(1, 3, 32, 32)) # torch.jit.script doesn't work with pystiche

torch.jit.save(model, path)
model = torch.jit.load(path)

out = model(torch.rand(1, 3, 32, 32))
assert isinstance(out, torch.Tensor)
assert out.shape == torch.Size([1, 3, 32, 32])
20 changes: 20 additions & 0 deletions tests/tabular/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -70,3 +72,21 @@ def test_init_train_no_cat(tmpdir):
def test_module_import_error(tmpdir):
with pytest.raises(ModuleNotFoundError, match="[tabular]"):
TabularClassifier(num_classes=10, num_features=16, embedding_sizes=[])


@pytest.mark.skipif(not _TABULAR_AVAILABLE, reason="tabular libraries aren't installed.")
def test_jit(tmpdir):
model = TabularClassifier(num_classes=10, num_features=8, embedding_sizes=4 * [(10, 32)])
model.eval()

# torch.jit.script doesn't work with tabnet
model = torch.jit.trace(model, ((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4)), ))

# TODO: torch.jit.save doesn't work with tabnet
# path = os.path.join(tmpdir, "test.pt")
# torch.jit.save(model, path)
# model = torch.jit.load(path)

out = model((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4)))
assert isinstance(out, torch.Tensor)
assert out.shape == torch.Size([1, 10])
20 changes: 20 additions & 0 deletions tests/template/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -116,3 +118,21 @@ def test_predict_sklearn():
data_pipe = DataPipeline(preprocess=TemplatePreprocess())
out = model.predict(bunch, data_source="sklearn", data_pipeline=data_pipe)
assert isinstance(out[0], int)


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16), ))])
def test_jit(tmpdir, jitter, args):
path = os.path.join(tmpdir, "test.pt")

model = TemplateSKLearnClassifier(num_features=16, num_classes=10)
model.eval()

model = jitter(model, *args)

torch.jit.save(model, path)
model = torch.jit.load(path)

out = model(torch.rand(1, 16))
assert isinstance(out, torch.Tensor)
assert out.shape == torch.Size([1, 10])