Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add support for Torch ORT to Transformer based Tasks #667

Merged
merged 22 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

- Added Torch ORT support to Transformer based tasks ([#667](https://github.com/PyTorchLightning/lightning-flash/pull/667))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
18 changes: 18 additions & 0 deletions docs/source/reference/summarization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,21 @@ You can now perform inference from your client like this:
.. literalinclude:: ../../../flash_examples/serve/summarization/client.py
:language: python
:lines: 14-

------

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``SummarizationTask`` once installed. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__.

.. note::

Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models.

.. code-block:: python

...

model = SummarizationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True)
18 changes: 18 additions & 0 deletions docs/source/reference/text_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,21 @@ You can now perform inference from your client like this:
.. literalinclude:: ../../../flash_examples/serve/text_classification/client.py
:language: python
:lines: 14-

------

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__.

.. note::

Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models.

.. code-block:: python

...

model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
18 changes: 18 additions & 0 deletions docs/source/reference/translation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,21 @@ You can now perform inference from your client like this:
.. literalinclude:: ../../../flash_examples/serve/translation/client.py
:language: python
:lines: 14-

------

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TranslationTask`` once installed. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__.

.. note::

Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models.

.. code-block:: python

...

model = TranslationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True)
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _compare_version(package: str, op, version) -> bool:
_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
_DATASETS_AVAILABLE = _module_available("datasets")
_ICEVISION_AVAILABLE = _module_available("icevision")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")

if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
Expand Down
25 changes: 17 additions & 8 deletions flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import torch
from pytorch_lightning import Callback
from torchmetrics import Metric

from flash.core.classification import ClassificationTask, Labels
from flash.core.data.process import Serializer
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text.ort_callback import ORTCallback

if _TEXT_AVAILABLE:
from transformers import BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import AutoModelForSequenceClassification
from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput


class TextClassifier(ClassificationTask):
Expand All @@ -43,6 +45,7 @@ class TextClassifier(ClassificationTask):
learning_rate: Learning rate to use for training, defaults to `1e-3`
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

required_extras: str = "text"
Expand All @@ -57,6 +60,7 @@ def __init__(
learning_rate: float = 1e-2,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
enable_ort: bool = False,
):
self.save_hyperparameters()

Expand All @@ -76,25 +80,24 @@ def __init__(
multi_label=multi_label,
serializer=serializer or Labels(multi_label=multi_label),
)
self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)

self.enable_ort = enable_ort
self.model = AutoModelForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
self.save_hyperparameters()

@property
def backbone(self):
# see huggingface's BertForSequenceClassification
return self.model.bert
return self.model.base_model

def forward(self, batch: Dict[str, torch.Tensor]):
return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None))

def to_loss_format(self, x) -> torch.Tensor:
if isinstance(x, SequenceClassifierOutput):
if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
x = x.logits
return super().to_loss_format(x)

def to_metrics_format(self, x) -> torch.Tensor:
if isinstance(x, SequenceClassifierOutput):
if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
x = x.logits
return super().to_metrics_format(x)

Expand All @@ -112,3 +115,9 @@ def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"]
else:
assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"]

def configure_callbacks(self) -> List[Callback]:
callbacks = super().configure_callbacks() or []
if self.enable_ort:
callbacks.append(ORTCallback())
return callbacks
52 changes: 52 additions & 0 deletions flash/text/ort_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash import Trainer
from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE

if _TORCH_ORT_AVAILABLE:
from torch_ort import ORTModule


class ORTCallback(Callback):
"""Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.

Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for
training and inference.

Usage:

# via Transformer Tasks
model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)

# or via the trainer
trainer = flash.Trainer(callbacks=ORTCallback())
"""

def __init__(self):
if not _TORCH_ORT_AVAILABLE:
raise MisconfigurationException(
"Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort"
)

def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
if not hasattr(pl_module, "model"):
raise MisconfigurationException(
"Torch ORT requires to wrap a single model that defines a forward function "
"assigned as `model` inside the `LightningModule`."
)
if not isinstance(pl_module.model, ORTModule):
pl_module.model = ORTModule(pl_module.model)
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 11 additions & 0 deletions flash/text/seq2seq/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union

import torch
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_info
from torch import Tensor
from torchmetrics import Metric

from flash.core.finetuning import FlashBaseFinetuning
from flash.core.model import Task
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text.ort_callback import ORTCallback
from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings

if _TEXT_AVAILABLE:
Expand Down Expand Up @@ -54,6 +56,7 @@ class Seq2SeqTask(Task):
learning_rate: Learning rate to use for training, defaults to `3e-4`
val_target_max_length: Maximum length of targets in validation. Defaults to `128`
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

required_extras: str = "text"
Expand All @@ -67,6 +70,7 @@ def __init__(
learning_rate: float = 5e-5,
val_target_max_length: Optional[int] = None,
num_beams: Optional[int] = None,
enable_ort: bool = False,
):
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
Expand All @@ -75,6 +79,7 @@ def __init__(
os.environ["PYTHONWARNINGS"] = "ignore"
super().__init__(loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate)
self.model = AutoModelForSeq2SeqLM.from_pretrained(backbone)
self.enable_ort = enable_ort
self.val_target_max_length = val_target_max_length
self.num_beams = num_beams
self._initialize_model_specific_parameters()
Expand Down Expand Up @@ -134,3 +139,9 @@ def tokenize_labels(self, labels: Tensor) -> List[str]:

def configure_finetune_callback(self) -> List[FlashBaseFinetuning]:
return [Seq2SeqFreezeEmbeddings(self.model.config.model_type, train_bn=True)]

def configure_callbacks(self) -> List[Callback]:
callbacks = super().configure_callbacks() or []
if self.enable_ort:
callbacks.append(ORTCallback())
return callbacks
3 changes: 3 additions & 0 deletions flash/text/seq2seq/question_answering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class QuestionAnsweringTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

def __init__(
Expand All @@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
use_stemmer: bool = True,
rouge_newline_sep: bool = True,
enable_ort: bool = True,
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
):
self.save_hyperparameters()
super().__init__(
Expand All @@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
enable_ort=enable_ort,
)
self.rouge = RougeMetric(
rouge_newline_sep=rouge_newline_sep,
Expand Down
3 changes: 3 additions & 0 deletions flash/text/seq2seq/summarization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class SummarizationTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

def __init__(
Expand All @@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
use_stemmer: bool = True,
rouge_newline_sep: bool = True,
enable_ort: bool = True,
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
):
self.save_hyperparameters()
super().__init__(
Expand All @@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
enable_ort=enable_ort,
)
self.rouge = RougeMetric(
rouge_newline_sep=rouge_newline_sep,
Expand Down
3 changes: 3 additions & 0 deletions flash/text/seq2seq/translation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class TranslationTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
n_gram: Maximum n_grams to use in metric calculation. Defaults to `4`
smooth: Apply smoothing in BLEU calculation. Defaults to `True`
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

def __init__(
Expand All @@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
n_gram: bool = 4,
smooth: bool = True,
enable_ort: bool = True,
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
):
self.save_hyperparameters()
super().__init__(
Expand All @@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
enable_ort=enable_ort,
)
self.bleu = BLEUScore(
n_gram=n_gram,
Expand Down
62 changes: 62 additions & 0 deletions tests/text/classification/test_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from pytorch_lightning import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash import Trainer
from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE
from flash.text import TextClassifier
from flash.text.ort_callback import ORTCallback
from tests.helpers.boring_model import BoringModel
from tests.helpers.utils import _TEXT_TESTING
from tests.text.classification.test_model import DummyDataset, TEST_BACKBONE

if _TORCH_ORT_AVAILABLE:
from torch_ort import ORTModule


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_init_train_enable_ort(tmpdir):
class TestCallback(Callback):
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
assert isinstance(pl_module.model, ORTModule)

model = TextClassifier(2, TEST_BACKBONE, enable_ort=True)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback())
trainer.fit(
model,
train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
)
trainer.test(model, test_dataloaders=torch.utils.data.DataLoader(DummyDataset()))


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_ort_callback_fails_no_model(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
trainer.fit(
model,
train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
)