From b49bf04cb0e802b02c2cd9f81e6a706f9b6895ae Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 10 Jun 2021 12:47:55 +0100 Subject: [PATCH] Jit support (#389) * Add support for jit script * Add jit test to image classification * Add jit for object detection * Add style transfer jit * Add jit support for embedding * Add jit for tabular and template * init * Add jit for text classification * Add video classification jit * Add seq2seq jit * Fixes * Add jit support matrix * Update CHANGELOG.md --- CHANGELOG.md | 1 + docs/source/general/jit.rst | 59 +++++++++++++++++++ docs/source/index.rst | 1 + flash/core/data/process.py | 4 +- flash/core/model.py | 8 ++- flash/image/classification/model.py | 3 - flash/image/detection/model.py | 9 ++- flash/image/embedding/model.py | 15 +++-- flash/image/segmentation/model.py | 5 +- flash/tabular/classification/model.py | 6 +- flash/text/classification/model.py | 32 +++++++++- tests/image/classification/test_model.py | 20 +++++++ tests/image/detection/test_model.py | 22 +++++++ tests/image/embedding/__init__.py | 0 tests/image/embedding/test_model.py | 38 ++++++++++++ tests/image/segmentation/test_model.py | 19 ++++++ tests/image/style_transfer/__init__.py | 0 tests/image/style_transfer/test_model.py | 20 +++++++ tests/tabular/classification/test_model.py | 20 +++++++ tests/template/classification/test_model.py | 20 +++++++ tests/text/classification/test_model.py | 19 ++++++ tests/text/seq2seq/core/__init__.py | 0 .../text/seq2seq/summarization/test_model.py | 21 +++++++ tests/text/seq2seq/translation/test_model.py | 21 +++++++ tests/video/__init__.py | 0 tests/video/classification/__init__.py | 0 .../test_model.py} | 19 ++++++ 27 files changed, 358 insertions(+), 24 deletions(-) create mode 100644 docs/source/general/jit.rst create mode 100644 tests/image/embedding/__init__.py create mode 100644 tests/image/embedding/test_model.py create mode 100644 tests/image/style_transfer/__init__.py create mode 100644 tests/text/seq2seq/core/__init__.py create mode 100644 tests/video/__init__.py create mode 100644 tests/video/classification/__init__.py rename tests/video/{test_video_classifier.py => classification/test_model.py} (91%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e9c90803b..7bc202168f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389)) ### Changed diff --git a/docs/source/general/jit.rst b/docs/source/general/jit.rst new file mode 100644 index 0000000000..2a6607f33c --- /dev/null +++ b/docs/source/general/jit.rst @@ -0,0 +1,59 @@ +####################### +TorchScript JIT Support +####################### + +.. _jit: + +We test all of our tasks for compatibility with :mod:`torch.jit`. +This table gives a breakdown of the supported features. + +.. list-table:: + :widths: 25 25 25 25 + :header-rows: 1 + + * - Task + - :func:`torch.jit.script` + - :func:`torch.jit.trace` + - :func:`torch.jit.save` + * - :class:`~flash.image.classification.model.ImageClassifier` + - Yes + - Yes + - Yes + * - :class:`~flash.image.detection.model.ObjectDetector` + - Yes + - No + - Yes + * - :class:`~flash.image.embedding.model.ImageEmbedder` + - Yes + - Yes + - Yes + * - :class:`~flash.image.segmentation.model.SemanticSegmentation` + - Yes + - Yes + - Yes + * - :class:`~flash.image.style_transfer.model.StyleTransfer` + - No + - Yes + - Yes + * - :class:`~flash.tabular.classification.model.TabularClassifier` + - No + - Yes + - No + * - :class:`~flash.text.classification.model.TabularClassifier` + - No + - Yes :sup:`*` + - Yes + * - :class:`~flash.text.seq2seq.summarization.model.SummarizationTask` + - No + - Yes + - Yes + * - :class:`~flash.text.seq2seq.translation.model.TranslationTask` + - No + - Yes + - Yes + * - :class:`~flash.video.classification.model.VideoClassifier` + - No + - Yes + - Yes + +:sup:`*` Only with ``strict=False``. diff --git a/docs/source/index.rst b/docs/source/index.rst index 34df14a940..8fb3169d28 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -49,6 +49,7 @@ Lightning Flash general/training general/finetuning general/predictions + general/jit .. toctree:: diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 3002ee4275..ac61ca2f51 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -49,7 +49,7 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): pass -class Preprocess(BasePreprocess, Properties, Module): +class Preprocess(BasePreprocess, Properties): """The :class:`~flash.core.data.process.Preprocess` encapsulates all the data processing logic that should run before the data is passed to the model. It is particularly useful when you want to provide an end to end implementation which works with 4 different stages: ``train``, ``validation``, ``test``, and inference (``predict``). @@ -454,7 +454,7 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): return cls(**state_dict) -class Postprocess(Properties, Module): +class Postprocess(Properties): def __init__(self, save_path: Optional[str] = None): super().__init__() diff --git a/flash/core/model.py b/flash/core/model.py index 0b344cc756..aeef402e27 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -263,12 +263,14 @@ def _resolve( return preprocess, postprocess, serializer + @torch.jit.unused @property def serializer(self) -> Optional[Serializer]: """The current :class:`.Serializer` associated with this model. If this property was set to a mapping (e.g. ``.serializer = {'output1': SerializerOne()}``) then this will be a :class:`.MappingSerializer`.""" return self._serializer + @torch.jit.unused @serializer.setter def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]): if isinstance(serializer, Mapping): @@ -350,12 +352,14 @@ def build_data_pipeline( self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) return data_pipeline + @torch.jit.unused @property def data_pipeline(self) -> DataPipeline: """The current :class:`.DataPipeline`. If set, the new value will override the :class:`.Task` defaults. See :py:meth:`~build_data_pipeline` for more details on the resolution order.""" return self.build_data_pipeline() + @torch.jit.unused @data_pipeline.setter def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: self._preprocess, self._postprocess, self.serializer = Task._resolve( @@ -366,14 +370,16 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: getattr(data_pipeline, '_postprocess_pipeline', None), getattr(data_pipeline, '_serializer', None), ) - self._preprocess.state_dict() + # self._preprocess.state_dict() if getattr(self._preprocess, "_ddp_params_and_buffers_to_ignore", None): self._ddp_params_and_buffers_to_ignore = self._preprocess._ddp_params_and_buffers_to_ignore + @torch.jit.unused @property def preprocess(self) -> Preprocess: return getattr(self.data_pipeline, '_preprocess_pipeline', None) + @torch.jit.unused @property def postprocess(self) -> Postprocess: return getattr(self.data_pipeline, '_postprocess_pipeline', None) diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index e172e5fe86..a802ca425a 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -14,14 +14,11 @@ from types import FunctionType from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union -import pytorch_lightning as pl import torch import torchmetrics -from pytorch_lightning.callbacks.base import Callback from torch import nn from torch.optim.lr_scheduler import _LRScheduler -import flash from flash.core.classification import ClassificationTask from flash.core.data.data_source import DefaultDataKeys from flash.core.data.process import Serializer diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 097bff8917..f26d219224 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -163,6 +163,9 @@ def get_model( model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator) return model + def forward(self, x: List[torch.Tensor]) -> Any: + return self.model(x) + def training_step(self, batch, batch_idx) -> Any: """The training step. Overrides ``Task.training_step`` """ @@ -178,7 +181,7 @@ def training_step(self, batch, batch_idx) -> Any: def validation_step(self, batch, batch_idx): images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] # fasterrcnn takes only images for eval() mode - outs = self.model(images) + outs = self(images) iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() self.log("val_iou", iou) @@ -188,13 +191,13 @@ def on_validation_end(self) -> None: def test_step(self, batch, batch_idx): images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] # fasterrcnn takes only images for eval() mode - outs = self.model(images) + outs = self(images) iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() self.log("test_iou", iou) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: images = batch[DefaultDataKeys.INPUT] - return self.model(images) + return self(images) def configure_finetune_callback(self): return [ObjectDetectionFineTuning(train_bn=True)] diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 3fa3df69db..06ee9152b2 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, Union import torch from pytorch_lightning.utilities.distributed import rank_zero_warn @@ -89,13 +89,12 @@ def __init__( rank_zero_warn('embedding_dim. Remember to finetune first!') def apply_pool(self, x): - if self.pooling_fn == torch.max: - # torch.max also returns argmax - x = self.pooling_fn(x, dim=-1)[0] - x = self.pooling_fn(x, dim=-1)[0] - else: - x = self.pooling_fn(x, dim=-1) - x = self.pooling_fn(x, dim=-1) + x = self.pooling_fn(x, dim=-1) + if torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]): + x = x[0] + x = self.pooling_fn(x, dim=-1) + if torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]): + x = x[0] return x def forward(self, x) -> torch.Tensor: diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 0e1dd06f7d..b24d4e9476 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -136,13 +136,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A def forward(self, x) -> torch.Tensor: # infer the image to the model - res: Union[torch.Tensor, Dict[str, torch.Tensor]] = self.backbone(x) + res = self.backbone(x) # some frameworks like torchvision return a dict. # In particular, torchvision segmentation models return the output logits # in the key `out`. - out: torch.Tensor - if isinstance(res, dict): + if torch.jit.isinstance(res, Dict[str, torch.Tensor]): out = res['out'] elif torch.is_tensor(res): out = res diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 57c026bf88..baee3fbccd 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -83,7 +83,11 @@ def __init__( def forward(self, x_in) -> torch.Tensor: # TabNet takes single input, x_in is composed of (categorical, numerical) - x = torch.cat([x for x in x_in if x.numel()], dim=1) + xs = [] + for x in x_in: + if x.numel(): + xs.append(x) + x = torch.cat(xs, dim=1) return self.model(x)[0] def training_step(self, batch: Any, batch_idx: int) -> Any: diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index e5a2667fe1..ccf98b7db9 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -76,12 +76,35 @@ def backbone(self): # see huggingface's BertForSequenceClassification return self.model.bert - def forward(self, batch_dict): - return self.model(**batch_dict) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None + ): + return self.model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) def step(self, batch, batch_idx) -> dict: output = {} - out = self.forward(batch) + out = self.forward(**batch) loss, logits = out[:2] output["loss"] = loss output["y_hat"] = logits @@ -91,6 +114,9 @@ def step(self, batch, batch_idx) -> dict: output["logs"] = {name: metric(probs, batch["labels"]) for name, metric in self.metrics.items()} return output + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self(**batch) + def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): """ This function is used only for debugging usage with CI diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index 102dd0ecad..a03bd16a54 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest import torch @@ -108,3 +110,21 @@ def test_multilabel(tmpdir): assert (torch.tensor(predictions) < 0).sum() == 0 assert len(predictions[0]) == num_classes == len(label) assert len(torch.unique(label)) <= 2 + + +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) +def test_jit(tmpdir, jitter, args): + path = os.path.join(tmpdir, "test.pt") + + model = ImageClassifier(2) + model.eval() + + model = jitter(model, *args) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(torch.rand(1, 3, 32, 32)) + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 2]) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index f568e825bb..925d37b6a2 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest import torch from pytorch_lightning import Trainer @@ -75,3 +77,23 @@ def test_training(tmpdir, model): dl = DataLoader(ds, collate_fn=collate_fn) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model, dl) + + +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +def test_jit(tmpdir): + path = os.path.join(tmpdir, "test.pt") + + model = ObjectDetector(2) + model.eval() + + model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model([torch.rand(3, 32, 32)]) + + # torchvision RCNN always returns a (Losses, Detections) tuple in scripting + out = out[1] + + assert {"boxes", "labels", "scores"} <= out[0].keys() diff --git a/tests/image/embedding/__init__.py b/tests/image/embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py new file mode 100644 index 0000000000..0c43035451 --- /dev/null +++ b/tests/image/embedding/test_model.py @@ -0,0 +1,38 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +import torch + +from flash.core.utilities.imports import _IMAGE_AVAILABLE +from flash.image import ImageEmbedder + + +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) +def test_jit(tmpdir, jitter, args): + path = os.path.join(tmpdir, "test.pt") + + model = ImageEmbedder(embedding_dim=128) + model.eval() + + model = jitter(model, *args) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(torch.rand(1, 3, 32, 32)) + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 128]) diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index 176305a25e..a2d256cacf 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Tuple import numpy as np @@ -114,3 +115,21 @@ def test_predict_numpy(): out = model.predict(img, data_source="numpy", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) assert out[0].shape == (10, 20) + + +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) +def test_jit(tmpdir, jitter, args): + path = os.path.join(tmpdir, "test.pt") + + model = SemanticSegmentation(2) + model.eval() + + model = jitter(model, *args) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(torch.rand(1, 3, 32, 32)) + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 2, 32, 32]) diff --git a/tests/image/style_transfer/__init__.py b/tests/image/style_transfer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py index fbcdd6c7ad..4cd1b05c9d 100644 --- a/tests/image/style_transfer/test_model.py +++ b/tests/image/style_transfer/test_model.py @@ -1,4 +1,7 @@ +import os + import pytest +import torch from flash.core.utilities.imports import _IMAGE_STLYE_TRANSFER, _PYSTICHE_GREATER_EQUAL_0_7_2 from flash.image.style_transfer import StyleTransfer @@ -20,3 +23,20 @@ def test_style_transfer_task(): def test_style_transfer_task_import(): with pytest.raises(ModuleNotFoundError, match="[image_style_transfer]"): StyleTransfer() + + +@pytest.mark.skipif(not _PYSTICHE_GREATER_EQUAL_0_7_2, reason="image style transfer libraries aren't installed.") +def test_jit(tmpdir): + path = os.path.join(tmpdir, "test.pt") + + model = StyleTransfer() + model.eval() + + model = torch.jit.trace(model, torch.rand(1, 3, 32, 32)) # torch.jit.script doesn't work with pystiche + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(torch.rand(1, 3, 32, 32)) + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 3, 32, 32]) diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index 56bb096095..e2fbf11edb 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest import torch from pytorch_lightning import Trainer @@ -70,3 +72,21 @@ def test_init_train_no_cat(tmpdir): def test_module_import_error(tmpdir): with pytest.raises(ModuleNotFoundError, match="[tabular]"): TabularClassifier(num_classes=10, num_features=16, embedding_sizes=[]) + + +@pytest.mark.skipif(not _TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") +def test_jit(tmpdir): + model = TabularClassifier(num_classes=10, num_features=8, embedding_sizes=4 * [(10, 32)]) + model.eval() + + # torch.jit.script doesn't work with tabnet + model = torch.jit.trace(model, ((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4)), )) + + # TODO: torch.jit.save doesn't work with tabnet + # path = os.path.join(tmpdir, "test.pt") + # torch.jit.save(model, path) + # model = torch.jit.load(path) + + out = model((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4))) + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 10]) diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py index aadbb2fa04..9fa57b80b9 100644 --- a/tests/template/classification/test_model.py +++ b/tests/template/classification/test_model.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import numpy as np import pytest import torch @@ -116,3 +118,21 @@ def test_predict_sklearn(): data_pipe = DataPipeline(preprocess=TemplatePreprocess()) out = model.predict(bunch, data_source="sklearn", data_pipeline=data_pipe) assert isinstance(out[0], int) + + +@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16), ))]) +def test_jit(tmpdir, jitter, args): + path = os.path.join(tmpdir, "test.pt") + + model = TemplateSKLearnClassifier(num_features=16, num_classes=10) + model.eval() + + model = jitter(model, *args) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(torch.rand(1, 16)) + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 10]) diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index 551bcf1a49..c811fdfa34 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -47,3 +47,22 @@ def test_init_train(tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset()) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model, train_dl) + + +@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +def test_jit(tmpdir): + sample_input = torch.randint(1000, size=(1, 100)) + path = os.path.join(tmpdir, "test.pt") + + model = TextClassifier(2, TEST_BACKBONE) + model.eval() + + # Huggingface bert model only supports `torch.jit.trace` with `strict=False` + model = torch.jit.trace(model, sample_input, strict=False) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(sample_input)["logits"] + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 2]) diff --git a/tests/text/seq2seq/core/__init__.py b/tests/text/seq2seq/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py index 33c8044aaa..7e380e331e 100644 --- a/tests/text/seq2seq/summarization/test_model.py +++ b/tests/text/seq2seq/summarization/test_model.py @@ -47,3 +47,24 @@ def test_init_train(tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset()) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model, train_dl) + + +@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +def test_jit(tmpdir): + sample_input = { + "input_ids": torch.randint(1000, size=(1, 32)), + "attention_mask": torch.randint(1, size=(1, 32)), + } + path = os.path.join(tmpdir, "test.pt") + + model = SummarizationTask(TEST_BACKBONE) + model.eval() + + # Huggingface only supports `torch.jit.trace` + model = torch.jit.trace(model, [sample_input]) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(sample_input) + assert isinstance(out, torch.Tensor) diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py index 81ac36a65c..808d0a0ada 100644 --- a/tests/text/seq2seq/translation/test_model.py +++ b/tests/text/seq2seq/translation/test_model.py @@ -47,3 +47,24 @@ def test_init_train(tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset()) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model, train_dl) + + +@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +def test_jit(tmpdir): + sample_input = { + "input_ids": torch.randint(128, size=(1, 4)), + "attention_mask": torch.randint(1, size=(1, 4)), + } + path = os.path.join(tmpdir, "test.pt") + + model = TranslationTask(TEST_BACKBONE, val_target_max_length=None) + model.eval() + + # Huggingface only supports `torch.jit.trace` + model = torch.jit.trace(model, [sample_input]) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(sample_input) + assert isinstance(out, torch.Tensor) diff --git a/tests/video/__init__.py b/tests/video/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/video/classification/__init__.py b/tests/video/classification/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/video/test_video_classifier.py b/tests/video/classification/test_model.py similarity index 91% rename from tests/video/test_video_classifier.py rename to tests/video/classification/test_model.py index caf445af65..a9830cea26 100644 --- a/tests/video/test_video_classifier.py +++ b/tests/video/classification/test_model.py @@ -158,3 +158,22 @@ def test_image_classifier_finetune(tmpdir): trainer = flash.Trainer(fast_dev_run=True) trainer.finetune(model, datamodule=datamodule) + + +@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") +def test_jit(tmpdir): + sample_input = torch.rand(1, 3, 32, 256, 256) + path = os.path.join(tmpdir, "test.pt") + + model = VideoClassifier(2, pretrained=False) + model.eval() + + # pytorchvideo only works with `torch.jit.trace` + model = torch.jit.trace(model, sample_input) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(sample_input) + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 2])