Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add support for Torch ORT to Transformer based Tasks (#667)
Browse files Browse the repository at this point in the history
* Add torch ORT support, move transformer Tasks to use general task class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformers version

* Revert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add tests

* Add tests

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add docs for text classification and translation

* Add note

* Add CHANGELOG.md

* Address code review

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
3 people authored Aug 17, 2021
1 parent 4e89a37 commit 741a838
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

- Added Torch ORT support to Transformer based tasks ([#667](https://github.com/PyTorchLightning/lightning-flash/pull/667))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
18 changes: 18 additions & 0 deletions docs/source/reference/summarization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,21 @@ You can now perform inference from your client like this:
.. literalinclude:: ../../../flash_examples/serve/summarization/client.py
:language: python
:lines: 14-

------

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``SummarizationTask`` once installed. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__.

.. note::

Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models.

.. code-block:: python
...
model = SummarizationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True)
18 changes: 18 additions & 0 deletions docs/source/reference/text_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,21 @@ You can now perform inference from your client like this:
.. literalinclude:: ../../../flash_examples/serve/text_classification/client.py
:language: python
:lines: 14-

------

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__.

.. note::

Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models.

.. code-block:: python
...
model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
18 changes: 18 additions & 0 deletions docs/source/reference/translation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,21 @@ You can now perform inference from your client like this:
.. literalinclude:: ../../../flash_examples/serve/translation/client.py
:language: python
:lines: 14-

------

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TranslationTask`` once installed. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__.

.. note::

Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models.

.. code-block:: python
...
model = TranslationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True)
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _compare_version(package: str, op, version) -> bool:
_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
_DATASETS_AVAILABLE = _module_available("datasets")
_ICEVISION_AVAILABLE = _module_available("icevision")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")

if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
Expand Down
25 changes: 17 additions & 8 deletions flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import torch
from pytorch_lightning import Callback
from torchmetrics import Metric

from flash.core.classification import ClassificationTask, Labels
from flash.core.data.process import Serializer
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text.ort_callback import ORTCallback

if _TEXT_AVAILABLE:
from transformers import BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import AutoModelForSequenceClassification
from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput


class TextClassifier(ClassificationTask):
Expand All @@ -43,6 +45,7 @@ class TextClassifier(ClassificationTask):
learning_rate: Learning rate to use for training, defaults to `1e-3`
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

required_extras: str = "text"
Expand All @@ -57,6 +60,7 @@ def __init__(
learning_rate: float = 1e-2,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
enable_ort: bool = False,
):
self.save_hyperparameters()

Expand All @@ -76,25 +80,24 @@ def __init__(
multi_label=multi_label,
serializer=serializer or Labels(multi_label=multi_label),
)
self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)

self.enable_ort = enable_ort
self.model = AutoModelForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
self.save_hyperparameters()

@property
def backbone(self):
# see huggingface's BertForSequenceClassification
return self.model.bert
return self.model.base_model

def forward(self, batch: Dict[str, torch.Tensor]):
return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None))

def to_loss_format(self, x) -> torch.Tensor:
if isinstance(x, SequenceClassifierOutput):
if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
x = x.logits
return super().to_loss_format(x)

def to_metrics_format(self, x) -> torch.Tensor:
if isinstance(x, SequenceClassifierOutput):
if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
x = x.logits
return super().to_metrics_format(x)

Expand All @@ -112,3 +115,9 @@ def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"]
else:
assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"]

def configure_callbacks(self) -> List[Callback]:
callbacks = super().configure_callbacks() or []
if self.enable_ort:
callbacks.append(ORTCallback())
return callbacks
52 changes: 52 additions & 0 deletions flash/text/ort_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash import Trainer
from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE

if _TORCH_ORT_AVAILABLE:
from torch_ort import ORTModule


class ORTCallback(Callback):
"""Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.
Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for
training and inference.
Usage:
# via Transformer Tasks
model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
# or via the trainer
trainer = flash.Trainer(callbacks=ORTCallback())
"""

def __init__(self):
if not _TORCH_ORT_AVAILABLE:
raise MisconfigurationException(
"Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort"
)

def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
if not hasattr(pl_module, "model"):
raise MisconfigurationException(
"Torch ORT requires to wrap a single model that defines a forward function "
"assigned as `model` inside the `LightningModule`."
)
if not isinstance(pl_module.model, ORTModule):
pl_module.model = ORTModule(pl_module.model)
11 changes: 11 additions & 0 deletions flash/text/seq2seq/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union

import torch
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_info
from torch import Tensor
from torchmetrics import Metric

from flash.core.finetuning import FlashBaseFinetuning
from flash.core.model import Task
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text.ort_callback import ORTCallback
from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings

if _TEXT_AVAILABLE:
Expand Down Expand Up @@ -54,6 +56,7 @@ class Seq2SeqTask(Task):
learning_rate: Learning rate to use for training, defaults to `3e-4`
val_target_max_length: Maximum length of targets in validation. Defaults to `128`
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

required_extras: str = "text"
Expand All @@ -67,6 +70,7 @@ def __init__(
learning_rate: float = 5e-5,
val_target_max_length: Optional[int] = None,
num_beams: Optional[int] = None,
enable_ort: bool = False,
):
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
Expand All @@ -75,6 +79,7 @@ def __init__(
os.environ["PYTHONWARNINGS"] = "ignore"
super().__init__(loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate)
self.model = AutoModelForSeq2SeqLM.from_pretrained(backbone)
self.enable_ort = enable_ort
self.val_target_max_length = val_target_max_length
self.num_beams = num_beams
self._initialize_model_specific_parameters()
Expand Down Expand Up @@ -134,3 +139,9 @@ def tokenize_labels(self, labels: Tensor) -> List[str]:

def configure_finetune_callback(self) -> List[FlashBaseFinetuning]:
return [Seq2SeqFreezeEmbeddings(self.model.config.model_type, train_bn=True)]

def configure_callbacks(self) -> List[Callback]:
callbacks = super().configure_callbacks() or []
if self.enable_ort:
callbacks.append(ORTCallback())
return callbacks
3 changes: 3 additions & 0 deletions flash/text/seq2seq/question_answering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class QuestionAnsweringTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

def __init__(
Expand All @@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
use_stemmer: bool = True,
rouge_newline_sep: bool = True,
enable_ort: bool = False,
):
self.save_hyperparameters()
super().__init__(
Expand All @@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
enable_ort=enable_ort,
)
self.rouge = RougeMetric(
rouge_newline_sep=rouge_newline_sep,
Expand Down
3 changes: 3 additions & 0 deletions flash/text/seq2seq/summarization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class SummarizationTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

def __init__(
Expand All @@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
use_stemmer: bool = True,
rouge_newline_sep: bool = True,
enable_ort: bool = False,
):
self.save_hyperparameters()
super().__init__(
Expand All @@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
enable_ort=enable_ort,
)
self.rouge = RougeMetric(
rouge_newline_sep=rouge_newline_sep,
Expand Down
3 changes: 3 additions & 0 deletions flash/text/seq2seq/translation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class TranslationTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
n_gram: Maximum n_grams to use in metric calculation. Defaults to `4`
smooth: Apply smoothing in BLEU calculation. Defaults to `True`
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

def __init__(
Expand All @@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
n_gram: bool = 4,
smooth: bool = True,
enable_ort: bool = False,
):
self.save_hyperparameters()
super().__init__(
Expand All @@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
enable_ort=enable_ort,
)
self.bleu = BLEUScore(
n_gram=n_gram,
Expand Down
62 changes: 62 additions & 0 deletions tests/text/classification/test_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from pytorch_lightning import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash import Trainer
from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE
from flash.text import TextClassifier
from flash.text.ort_callback import ORTCallback
from tests.helpers.boring_model import BoringModel
from tests.helpers.utils import _TEXT_TESTING
from tests.text.classification.test_model import DummyDataset, TEST_BACKBONE

if _TORCH_ORT_AVAILABLE:
from torch_ort import ORTModule


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_init_train_enable_ort(tmpdir):
class TestCallback(Callback):
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
assert isinstance(pl_module.model, ORTModule)

model = TextClassifier(2, TEST_BACKBONE, enable_ort=True)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback())
trainer.fit(
model,
train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
)
trainer.test(model, test_dataloaders=torch.utils.data.DataLoader(DummyDataset()))


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_ort_callback_fails_no_model(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
trainer.fit(
model,
train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
)

0 comments on commit 741a838

Please sign in to comment.