From 39a6009970b89cd3aed485c25748319ee8694384 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 16 Aug 2021 19:25:48 +0100 Subject: [PATCH 01/19] Add torch ORT support, move transformer Tasks to use general task class --- flash/core/utilities/imports.py | 1 + flash/text/classification/model.py | 27 +++++++--- flash/text/ort_callback.py | 52 +++++++++++++++++++ flash/text/seq2seq/core/model.py | 11 ++++ .../text/seq2seq/question_answering/model.py | 3 ++ flash/text/seq2seq/summarization/model.py | 3 ++ flash/text/seq2seq/translation/model.py | 3 ++ 7 files changed, 93 insertions(+), 7 deletions(-) create mode 100644 flash/text/ort_callback.py diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index a1375fca9b..8080eefe3c 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -96,6 +96,7 @@ def _compare_version(package: str, op, version) -> bool: _ROUGE_SCORE_AVAILABLE = _module_available("rouge_score") _SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece") _DATASETS_AVAILABLE = _module_available("datasets") +_TORCH_ORT_AVAILABLE = _module_available("torch_ort") if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 3a0d78e1ff..ff03acd9b5 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -13,18 +13,20 @@ # limitations under the License. import os import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union, Tuple import torch +from pytorch_lightning import Callback from torchmetrics import Accuracy, F1, Metric from flash.core.classification import ClassificationTask, Labels from flash.core.data.process import Serializer from flash.core.utilities.imports import _TEXT_AVAILABLE +from flash.text.ort_callback import ORTCallback if _TEXT_AVAILABLE: - from transformers import BertForSequenceClassification - from transformers.modeling_outputs import SequenceClassifierOutput + from transformers import AutoModelForSequenceClassification + from transformers.modeling_outputs import SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput class TextClassifier(ClassificationTask): @@ -43,6 +45,7 @@ class TextClassifier(ClassificationTask): learning_rate: Learning rate to use for training, defaults to `1e-3` multi_label: Whether the targets are multi-label or not. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ required_extras: str = "text" @@ -57,6 +60,7 @@ def __init__( learning_rate: float = 1e-2, multi_label: bool = False, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + enable_ort: bool = False, ): self.save_hyperparameters() @@ -75,14 +79,17 @@ def __init__( multi_label=multi_label, serializer=serializer or Labels(multi_label=multi_label), ) - self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes) - + self.enable_ort = enable_ort + self.model = AutoModelForSequenceClassification.from_pretrained(backbone, num_labels=num_classes) self.save_hyperparameters() @property def backbone(self): - # see huggingface's BertForSequenceClassification - return self.model.bert + return self.model.base_model + + @staticmethod + def apply_filtering(y: torch.Tensor, y_hat: Seq2SeqSequenceClassifierOutput) -> Tuple[torch.Tensor, torch.Tensor]: + return y, y_hat.logits def forward(self, batch: Dict[str, torch.Tensor]): return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None)) @@ -111,3 +118,9 @@ def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"] else: assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"] + + def configure_callbacks(self) -> List[Callback]: + callbacks = super().configure_callbacks() or [] + if self.enable_ort: + callbacks.append(ORTCallback()) + return callbacks diff --git a/flash/text/ort_callback.py b/flash/text/ort_callback.py new file mode 100644 index 0000000000..afb8d9695e --- /dev/null +++ b/flash/text/ort_callback.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE + +if _TORCH_ORT_AVAILABLE: + from torch_ort import ORTModule + + +class ORTCallback(Callback): + """ + Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime. + + Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for + training and inference. + + Usage: + + # via Transformer Tasks + model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True) + + # or via the trainer + trainer = flash.Trainer(callbacks=ORTCallback()) + + """ + + def __init__(self): + if not _TORCH_ORT_AVAILABLE: + raise MisconfigurationException( + "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort" + ) + + def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not hasattr(pl_module, 'model'): + raise MisconfigurationException( + "Torch ORT requires to wrap a single model that defines a forward function " + "assigned as `model` inside the `LightningModule`." + ) + if not isinstance(pl_module.model, ORTModule): + pl_module.model = ORTModule(pl_module.model) diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index 283abaf120..fb12140cdd 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -16,6 +16,8 @@ from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union import torch +from flash.text.ort_callback import ORTCallback +from pytorch_lightning import Callback from pytorch_lightning.utilities import rank_zero_info from torch import Tensor from torchmetrics import Metric @@ -54,6 +56,7 @@ class Seq2SeqTask(Task): learning_rate: Learning rate to use for training, defaults to `3e-4` val_target_max_length: Maximum length of targets in validation. Defaults to `128` num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ required_extras: str = "text" @@ -67,6 +70,7 @@ def __init__( learning_rate: float = 5e-5, val_target_max_length: Optional[int] = None, num_beams: Optional[int] = None, + enable_ort: bool = False ): os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" # disable HF thousand warnings @@ -75,6 +79,7 @@ def __init__( os.environ["PYTHONWARNINGS"] = "ignore" super().__init__(loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate) self.model = AutoModelForSeq2SeqLM.from_pretrained(backbone) + self.enable_ort = enable_ort self.val_target_max_length = val_target_max_length self.num_beams = num_beams self._initialize_model_specific_parameters() @@ -134,3 +139,9 @@ def tokenize_labels(self, labels: Tensor) -> List[str]: def configure_finetune_callback(self) -> List[FlashBaseFinetuning]: return [Seq2SeqFreezeEmbeddings(self.model.config.model_type, train_bn=True)] + + def configure_callbacks(self) -> List[Callback]: + callbacks = super().configure_callbacks() or [] + if self.enable_ort: + callbacks.append(ORTCallback()) + return callbacks diff --git a/flash/text/seq2seq/question_answering/model.py b/flash/text/seq2seq/question_answering/model.py index 2db3a6d6aa..ae6ea178d0 100644 --- a/flash/text/seq2seq/question_answering/model.py +++ b/flash/text/seq2seq/question_answering/model.py @@ -42,6 +42,7 @@ class QuestionAnsweringTask(Seq2SeqTask): num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching. rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation. + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ def __init__( @@ -55,6 +56,7 @@ def __init__( num_beams: Optional[int] = 4, use_stemmer: bool = True, rouge_newline_sep: bool = True, + enable_ort: bool = True ): self.save_hyperparameters() super().__init__( @@ -65,6 +67,7 @@ def __init__( learning_rate=learning_rate, val_target_max_length=val_target_max_length, num_beams=num_beams, + enable_ort=enable_ort ) self.rouge = RougeMetric( rouge_newline_sep=rouge_newline_sep, diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index af7820b10e..86be9d579a 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -42,6 +42,7 @@ class SummarizationTask(Seq2SeqTask): num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching. rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation. + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ def __init__( @@ -55,6 +56,7 @@ def __init__( num_beams: Optional[int] = 4, use_stemmer: bool = True, rouge_newline_sep: bool = True, + enable_ort: bool = True ): self.save_hyperparameters() super().__init__( @@ -65,6 +67,7 @@ def __init__( learning_rate=learning_rate, val_target_max_length=val_target_max_length, num_beams=num_beams, + enable_ort=enable_ort ) self.rouge = RougeMetric( rouge_newline_sep=rouge_newline_sep, diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index ad99f47e31..49fd032bf2 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -42,6 +42,7 @@ class TranslationTask(Seq2SeqTask): num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` n_gram: Maximum n_grams to use in metric calculation. Defaults to `4` smooth: Apply smoothing in BLEU calculation. Defaults to `True` + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ def __init__( @@ -55,6 +56,7 @@ def __init__( num_beams: Optional[int] = 4, n_gram: bool = 4, smooth: bool = True, + enable_ort: bool = True ): self.save_hyperparameters() super().__init__( @@ -65,6 +67,7 @@ def __init__( learning_rate=learning_rate, val_target_max_length=val_target_max_length, num_beams=num_beams, + enable_ort=enable_ort ) self.bleu = BLEUScore( n_gram=n_gram, From a48195e495757d2a393e0df6151833b31006d1f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Aug 2021 18:29:15 +0000 Subject: [PATCH 02/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/text/classification/model.py | 4 ++-- flash/text/ort_callback.py | 7 +++---- flash/text/seq2seq/core/model.py | 4 ++-- flash/text/seq2seq/question_answering/model.py | 4 ++-- flash/text/seq2seq/summarization/model.py | 4 ++-- flash/text/seq2seq/translation/model.py | 4 ++-- 6 files changed, 13 insertions(+), 14 deletions(-) diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index cf798d6514..92bbf09f33 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -13,7 +13,7 @@ # limitations under the License. import os import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union, Tuple +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union import torch from pytorch_lightning import Callback @@ -26,7 +26,7 @@ if _TEXT_AVAILABLE: from transformers import AutoModelForSequenceClassification - from transformers.modeling_outputs import SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput + from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput class TextClassifier(ClassificationTask): diff --git a/flash/text/ort_callback.py b/flash/text/ort_callback.py index afb8d9695e..080d466fe8 100644 --- a/flash/text/ort_callback.py +++ b/flash/text/ort_callback.py @@ -13,6 +13,7 @@ # limitations under the License. from pytorch_lightning import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException + from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE if _TORCH_ORT_AVAILABLE: @@ -20,8 +21,7 @@ class ORTCallback(Callback): - """ - Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime. + """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime. Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for training and inference. @@ -33,7 +33,6 @@ class ORTCallback(Callback): # or via the trainer trainer = flash.Trainer(callbacks=ORTCallback()) - """ def __init__(self): @@ -43,7 +42,7 @@ def __init__(self): ) def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if not hasattr(pl_module, 'model'): + if not hasattr(pl_module, "model"): raise MisconfigurationException( "Torch ORT requires to wrap a single model that defines a forward function " "assigned as `model` inside the `LightningModule`." diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index fb12140cdd..d79ca18a78 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -16,7 +16,6 @@ from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union import torch -from flash.text.ort_callback import ORTCallback from pytorch_lightning import Callback from pytorch_lightning.utilities import rank_zero_info from torch import Tensor @@ -25,6 +24,7 @@ from flash.core.finetuning import FlashBaseFinetuning from flash.core.model import Task from flash.core.utilities.imports import _TEXT_AVAILABLE +from flash.text.ort_callback import ORTCallback from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings if _TEXT_AVAILABLE: @@ -70,7 +70,7 @@ def __init__( learning_rate: float = 5e-5, val_target_max_length: Optional[int] = None, num_beams: Optional[int] = None, - enable_ort: bool = False + enable_ort: bool = False, ): os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" # disable HF thousand warnings diff --git a/flash/text/seq2seq/question_answering/model.py b/flash/text/seq2seq/question_answering/model.py index ae6ea178d0..c7aa2158af 100644 --- a/flash/text/seq2seq/question_answering/model.py +++ b/flash/text/seq2seq/question_answering/model.py @@ -56,7 +56,7 @@ def __init__( num_beams: Optional[int] = 4, use_stemmer: bool = True, rouge_newline_sep: bool = True, - enable_ort: bool = True + enable_ort: bool = True, ): self.save_hyperparameters() super().__init__( @@ -67,7 +67,7 @@ def __init__( learning_rate=learning_rate, val_target_max_length=val_target_max_length, num_beams=num_beams, - enable_ort=enable_ort + enable_ort=enable_ort, ) self.rouge = RougeMetric( rouge_newline_sep=rouge_newline_sep, diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index 86be9d579a..3ad9fea350 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -56,7 +56,7 @@ def __init__( num_beams: Optional[int] = 4, use_stemmer: bool = True, rouge_newline_sep: bool = True, - enable_ort: bool = True + enable_ort: bool = True, ): self.save_hyperparameters() super().__init__( @@ -67,7 +67,7 @@ def __init__( learning_rate=learning_rate, val_target_max_length=val_target_max_length, num_beams=num_beams, - enable_ort=enable_ort + enable_ort=enable_ort, ) self.rouge = RougeMetric( rouge_newline_sep=rouge_newline_sep, diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index 49fd032bf2..5172b659ae 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -56,7 +56,7 @@ def __init__( num_beams: Optional[int] = 4, n_gram: bool = 4, smooth: bool = True, - enable_ort: bool = True + enable_ort: bool = True, ): self.save_hyperparameters() super().__init__( @@ -67,7 +67,7 @@ def __init__( learning_rate=learning_rate, val_target_max_length=val_target_max_length, num_beams=num_beams, - enable_ort=enable_ort + enable_ort=enable_ort, ) self.bleu = BLEUScore( n_gram=n_gram, From 5f8d3f3fbed1cf329804ae25da0c03003e4942fb Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 11:41:17 +0100 Subject: [PATCH 03/19] Fix import --- flash/text/classification/model.py | 4 +- flash_examples/text_classification.py | 111 +++++++++++++++++++------- 2 files changed, 82 insertions(+), 33 deletions(-) diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 92bbf09f33..b2e32406ec 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -26,7 +26,7 @@ if _TEXT_AVAILABLE: from transformers import AutoModelForSequenceClassification - from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput + from transformers.modeling_outputs import SequenceClassifierOutput class TextClassifier(ClassificationTask): @@ -89,7 +89,7 @@ def backbone(self): return self.model.base_model @staticmethod - def apply_filtering(y: torch.Tensor, y_hat: Seq2SeqSequenceClassifierOutput) -> Tuple[torch.Tensor, torch.Tensor]: + def apply_filtering(y: torch.Tensor, y_hat: SequenceClassifierOutput) -> Tuple[torch.Tensor, torch.Tensor]: return y, y_hat.logits def forward(self, batch: Dict[str, torch.Tensor]): diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index 3d62dbb0dc..ef376503bd 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -11,39 +11,88 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch +import time +from typing import Any import flash +import psutil +import torch from flash.core.data.utils import download_data from flash.text import TextClassificationData, TextClassifier +from pytorch_lightning import Callback +from pytorch_lightning.plugins import DeepSpeedPlugin +from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.utilities.types import STEP_OUTPUT + + +class CUDACallback(Callback): + + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + if batch_idx == 1: + # only start at the second batch + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.root_gpu) + torch.cuda.synchronize(trainer.root_gpu) + self.start_time = time.time() + + def on_batch_end(self, trainer, pl_module) -> None: + torch.cuda.synchronize(trainer.root_gpu) + max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 + pl_module.log('Peak Memory (GiB)', max_memory / 1000, prog_bar=True, on_step=True, sync_dist=True) + + def on_train_epoch_end(self, trainer, pl_module, outputs): + torch.cuda.synchronize(trainer.root_gpu) + max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 + epoch_time = time.time() - self.start_time + virt_mem = psutil.virtual_memory() + virt_mem = round((virt_mem.used / (1024 ** 3)), 2) + swap = psutil.swap_memory() + swap = round((swap.used / (1024 ** 3)), 2) + + max_memory = trainer.training_type_plugin.reduce(max_memory) + epoch_time = trainer.training_type_plugin.reduce(epoch_time) + virt_mem = trainer.training_type_plugin.reduce(virt_mem) + swap = trainer.training_type_plugin.reduce(swap) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak CUDA memory {max_memory:.2f} MiB") + rank_zero_info(f"Average Peak Virtual memory {virt_mem:.2f} GiB") + rank_zero_info(f"Average Peak Swap memory {swap:.2f} Gib") + + +if __name__ == '__main__': + # 1. Create the DataModule + download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") + + datamodule = TextClassificationData.from_csv( + "review", + "sentiment", + train_file="data/imdb/train.csv", + val_file="data/imdb/valid.csv", + backbone="facebook/bart-large", + batch_size=4, + ) + + # 2. Build the task + model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=False) -# 1. Create the DataModule -download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") - -datamodule = TextClassificationData.from_csv( - "review", - "sentiment", - train_file="data/imdb/train.csv", - val_file="data/imdb/valid.csv", - backbone="prajjwal1/bert-medium", -) - -# 2. Build the task -model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes) - -# 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) -trainer.finetune(model, datamodule=datamodule, strategy="freeze") - -# 4. Classify a few sentences! How was the movie? -predictions = model.predict( - [ - "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", - "The worst movie in the history of cinema.", - "I come from Bulgaria where it 's almost impossible to have a tornado.", - ] -) -print(predictions) - -# 5. Save the model! -trainer.save_checkpoint("text_classification_model.pt") + # 3. Create the trainer and finetune the model + trainer = flash.Trainer( + max_epochs=1, + plugins=DeepSpeedPlugin(stage=1), + callbacks=CUDACallback(), + precision=16, + accelerator='ddp', + gpus=4, + limit_val_batches=0, + limit_test_batches=0 + ) + trainer.fit(model, datamodule=datamodule) From 97ae954351875a53c2d5b957a5829a2c0142eb96 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Aug 2021 10:42:12 +0000 Subject: [PATCH 04/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash_examples/text_classification.py | 30 +++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index ef376503bd..896f29b13f 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -14,27 +14,27 @@ import time from typing import Any -import flash import psutil import torch -from flash.core.data.utils import download_data -from flash.text import TextClassificationData, TextClassifier from pytorch_lightning import Callback from pytorch_lightning.plugins import DeepSpeedPlugin from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.types import STEP_OUTPUT +import flash +from flash.core.data.utils import download_data +from flash.text import TextClassificationData, TextClassifier + class CUDACallback(Callback): - def on_train_batch_end( - self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: STEP_OUTPUT, - batch: Any, - batch_idx: int, - dataloader_idx: int, + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int, ) -> None: if batch_idx == 1: # only start at the second batch @@ -46,7 +46,7 @@ def on_train_batch_end( def on_batch_end(self, trainer, pl_module) -> None: torch.cuda.synchronize(trainer.root_gpu) max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 - pl_module.log('Peak Memory (GiB)', max_memory / 1000, prog_bar=True, on_step=True, sync_dist=True) + pl_module.log("Peak Memory (GiB)", max_memory / 1000, prog_bar=True, on_step=True, sync_dist=True) def on_train_epoch_end(self, trainer, pl_module, outputs): torch.cuda.synchronize(trainer.root_gpu) @@ -68,7 +68,7 @@ def on_train_epoch_end(self, trainer, pl_module, outputs): rank_zero_info(f"Average Peak Swap memory {swap:.2f} Gib") -if __name__ == '__main__': +if __name__ == "__main__": # 1. Create the DataModule download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") @@ -90,9 +90,9 @@ def on_train_epoch_end(self, trainer, pl_module, outputs): plugins=DeepSpeedPlugin(stage=1), callbacks=CUDACallback(), precision=16, - accelerator='ddp', + accelerator="ddp", gpus=4, limit_val_batches=0, - limit_test_batches=0 + limit_test_batches=0, ) trainer.fit(model, datamodule=datamodule) From 8afc41bed020b14b5e47d8dbe79ac7136e29dadf Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 11:57:36 +0100 Subject: [PATCH 05/19] Update transformers version --- requirements/datatype_text.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/datatype_text.txt b/requirements/datatype_text.txt index 9953e12545..f6df909705 100644 --- a/requirements/datatype_text.txt +++ b/requirements/datatype_text.txt @@ -1,5 +1,5 @@ rouge-score>=0.0.4 sentencepiece>=0.1.95 filelock -transformers>=4.5 +transformers>=4.7 datasets>=1.2, <1.3 From efb976b5c5375741b9f9099651a5dce06653fd00 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 12:02:51 +0100 Subject: [PATCH 06/19] Revert --- flash/text/ort_callback.py | 5 +- flash_examples/text_classification.py | 109 +++++++------------------- 2 files changed, 33 insertions(+), 81 deletions(-) diff --git a/flash/text/ort_callback.py b/flash/text/ort_callback.py index 080d466fe8..b3d1a615a3 100644 --- a/flash/text/ort_callback.py +++ b/flash/text/ort_callback.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning import Callback +from pytorch_lightning import Callback, LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException +from flash import Trainer from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE if _TORCH_ORT_AVAILABLE: @@ -41,7 +42,7 @@ def __init__(self): "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort" ) - def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None: if not hasattr(pl_module, "model"): raise MisconfigurationException( "Torch ORT requires to wrap a single model that defines a forward function " diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index 896f29b13f..6448eeb814 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -11,88 +11,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import time -from typing import Any - -import psutil import torch -from pytorch_lightning import Callback -from pytorch_lightning.plugins import DeepSpeedPlugin -from pytorch_lightning.utilities import rank_zero_info -from pytorch_lightning.utilities.types import STEP_OUTPUT import flash from flash.core.data.utils import download_data from flash.text import TextClassificationData, TextClassifier - -class CUDACallback(Callback): - def on_train_batch_end( - self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: STEP_OUTPUT, - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - if batch_idx == 1: - # only start at the second batch - # Reset the memory use counter - torch.cuda.reset_peak_memory_stats(trainer.root_gpu) - torch.cuda.synchronize(trainer.root_gpu) - self.start_time = time.time() - - def on_batch_end(self, trainer, pl_module) -> None: - torch.cuda.synchronize(trainer.root_gpu) - max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 - pl_module.log("Peak Memory (GiB)", max_memory / 1000, prog_bar=True, on_step=True, sync_dist=True) - - def on_train_epoch_end(self, trainer, pl_module, outputs): - torch.cuda.synchronize(trainer.root_gpu) - max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 - epoch_time = time.time() - self.start_time - virt_mem = psutil.virtual_memory() - virt_mem = round((virt_mem.used / (1024 ** 3)), 2) - swap = psutil.swap_memory() - swap = round((swap.used / (1024 ** 3)), 2) - - max_memory = trainer.training_type_plugin.reduce(max_memory) - epoch_time = trainer.training_type_plugin.reduce(epoch_time) - virt_mem = trainer.training_type_plugin.reduce(virt_mem) - swap = trainer.training_type_plugin.reduce(swap) - - rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") - rank_zero_info(f"Average Peak CUDA memory {max_memory:.2f} MiB") - rank_zero_info(f"Average Peak Virtual memory {virt_mem:.2f} GiB") - rank_zero_info(f"Average Peak Swap memory {swap:.2f} Gib") - - -if __name__ == "__main__": - # 1. Create the DataModule - download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") - - datamodule = TextClassificationData.from_csv( - "review", - "sentiment", - train_file="data/imdb/train.csv", - val_file="data/imdb/valid.csv", - backbone="facebook/bart-large", - batch_size=4, - ) - - # 2. Build the task - model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=False) - - # 3. Create the trainer and finetune the model - trainer = flash.Trainer( - max_epochs=1, - plugins=DeepSpeedPlugin(stage=1), - callbacks=CUDACallback(), - precision=16, - accelerator="ddp", - gpus=4, - limit_val_batches=0, - limit_test_batches=0, - ) - trainer.fit(model, datamodule=datamodule) +# 1. Create the DataModule +download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") + +datamodule = TextClassificationData.from_csv( + "review", + "sentiment", + train_file="data/imdb/train.csv", + val_file="data/imdb/valid.csv", + backbone="prajjwal1/bert-medium", +) + +# 2. Build the task +model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 4. Classify a few sentences! How was the movie? +predictions = model.predict( + [ + "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", + "The worst movie in the history of cinema.", + "I come from Bulgaria where it 's almost impossible to have a tornado.", + ] +) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("text_classification_model.pt") \ No newline at end of file From 88feeac6f497c33faabc7b86760e55f0ae4c2edd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Aug 2021 11:03:26 +0000 Subject: [PATCH 07/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash_examples/text_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index 6448eeb814..3d62dbb0dc 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -46,4 +46,4 @@ print(predictions) # 5. Save the model! -trainer.save_checkpoint("text_classification_model.pt") \ No newline at end of file +trainer.save_checkpoint("text_classification_model.pt") From 749bc0cc02ce5f3fdfd8464c328c856bf26fa5dd Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 13:18:11 +0100 Subject: [PATCH 08/19] Revert --- flash/text/classification/model.py | 10 +++------- requirements/datatype_text.txt | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index b2e32406ec..b7102c0507 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -26,7 +26,7 @@ if _TEXT_AVAILABLE: from transformers import AutoModelForSequenceClassification - from transformers.modeling_outputs import SequenceClassifierOutput + from transformers.modeling_outputs import SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput class TextClassifier(ClassificationTask): @@ -88,20 +88,16 @@ def __init__( def backbone(self): return self.model.base_model - @staticmethod - def apply_filtering(y: torch.Tensor, y_hat: SequenceClassifierOutput) -> Tuple[torch.Tensor, torch.Tensor]: - return y, y_hat.logits - def forward(self, batch: Dict[str, torch.Tensor]): return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None)) def to_loss_format(self, x) -> torch.Tensor: - if isinstance(x, SequenceClassifierOutput): + if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)): x = x.logits return super().to_loss_format(x) def to_metrics_format(self, x) -> torch.Tensor: - if isinstance(x, SequenceClassifierOutput): + if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)): x = x.logits return super().to_metrics_format(x) diff --git a/requirements/datatype_text.txt b/requirements/datatype_text.txt index f6df909705..9953e12545 100644 --- a/requirements/datatype_text.txt +++ b/requirements/datatype_text.txt @@ -1,5 +1,5 @@ rouge-score>=0.0.4 sentencepiece>=0.1.95 filelock -transformers>=4.7 +transformers>=4.5 datasets>=1.2, <1.3 From aa46d93945f40ec4d8f8f53e190efa289e4d4e8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Aug 2021 12:19:45 +0000 Subject: [PATCH 09/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/text/classification/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index b7102c0507..1092686760 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -26,7 +26,7 @@ if _TEXT_AVAILABLE: from transformers import AutoModelForSequenceClassification - from transformers.modeling_outputs import SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput + from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput class TextClassifier(ClassificationTask): From e4f11048e7fba9e432267e2627c744aef9c516d5 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 14:05:20 +0100 Subject: [PATCH 10/19] Add tests --- flash_examples/text_classification.py | 111 +++++++++++++++++++------- 1 file changed, 81 insertions(+), 30 deletions(-) diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index 3d62dbb0dc..e20912a572 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -11,39 +11,90 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import time +from typing import Any + +import psutil import torch +from pytorch_lightning import Callback +from pytorch_lightning.plugins import DeepSpeedPlugin +from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.utilities.types import STEP_OUTPUT import flash from flash.core.data.utils import download_data from flash.text import TextClassificationData, TextClassifier -# 1. Create the DataModule -download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") - -datamodule = TextClassificationData.from_csv( - "review", - "sentiment", - train_file="data/imdb/train.csv", - val_file="data/imdb/valid.csv", - backbone="prajjwal1/bert-medium", -) - -# 2. Build the task -model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes) - -# 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) -trainer.finetune(model, datamodule=datamodule, strategy="freeze") - -# 4. Classify a few sentences! How was the movie? -predictions = model.predict( - [ - "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", - "The worst movie in the history of cinema.", - "I come from Bulgaria where it 's almost impossible to have a tornado.", - ] -) -print(predictions) - -# 5. Save the model! -trainer.save_checkpoint("text_classification_model.pt") + +class CUDACallback(Callback): + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + if batch_idx == 1: + # only start at the second batch + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.root_gpu) + torch.cuda.synchronize(trainer.root_gpu) + self.start_time = time.time() + + def on_batch_end(self, trainer, pl_module) -> None: + torch.cuda.synchronize(trainer.root_gpu) + max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 + pl_module.log("Peak Memory (GiB)", max_memory / 1000, prog_bar=True, on_step=True, sync_dist=True) + + def on_train_epoch_end(self, trainer, pl_module, outputs): + torch.cuda.synchronize(trainer.root_gpu) + max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 + epoch_time = time.time() - self.start_time + virt_mem = psutil.virtual_memory() + virt_mem = round((virt_mem.used / (1024 ** 3)), 2) + swap = psutil.swap_memory() + swap = round((swap.used / (1024 ** 3)), 2) + + max_memory = trainer.training_type_plugin.reduce(max_memory) + epoch_time = trainer.training_type_plugin.reduce(epoch_time) + virt_mem = trainer.training_type_plugin.reduce(virt_mem) + swap = trainer.training_type_plugin.reduce(swap) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak CUDA memory {max_memory:.2f} MiB") + rank_zero_info(f"Average Peak Virtual memory {virt_mem:.2f} GiB") + rank_zero_info(f"Average Peak Swap memory {swap:.2f} Gib") + + +if __name__ == "__main__": + # 1. Create the DataModule + download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") + + datamodule = TextClassificationData.from_csv( + "review", + "sentiment", + train_file="data/imdb/train.csv", + val_file="data/imdb/valid.csv", + backbone="facebook/bart-large", + batch_size=4, + ) + + # 2. Build the task + model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=False) + + # 3. Create the trainer and finetune the model + trainer = flash.Trainer( + max_epochs=1, + plugins=DeepSpeedPlugin(stage=1), + callbacks=CUDACallback(), + precision=16, + accelerator="ddp", + gpus=1, + limit_train_batches=10, + limit_val_batches=10, + limit_test_batches=10, + ) + trainer.fit(model, datamodule=datamodule) + trainer.test(model) From 85cdc8acec7d68725a10a66a1c831f909cfd5cc3 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 14:07:01 +0100 Subject: [PATCH 11/19] Add tests --- flash/text/classification/model.py | 2 +- flash_examples/text_classification.py | 111 +++++++------------------- tests/text/classification/test_ort.py | 65 +++++++++++++++ 3 files changed, 96 insertions(+), 82 deletions(-) create mode 100644 tests/text/classification/test_ort.py diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 1092686760..cf339153a0 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -13,7 +13,7 @@ # limitations under the License. import os import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import torch from pytorch_lightning import Callback diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index e20912a572..6448eeb814 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -11,90 +11,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import time -from typing import Any - -import psutil import torch -from pytorch_lightning import Callback -from pytorch_lightning.plugins import DeepSpeedPlugin -from pytorch_lightning.utilities import rank_zero_info -from pytorch_lightning.utilities.types import STEP_OUTPUT import flash from flash.core.data.utils import download_data from flash.text import TextClassificationData, TextClassifier - -class CUDACallback(Callback): - def on_train_batch_end( - self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: STEP_OUTPUT, - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - if batch_idx == 1: - # only start at the second batch - # Reset the memory use counter - torch.cuda.reset_peak_memory_stats(trainer.root_gpu) - torch.cuda.synchronize(trainer.root_gpu) - self.start_time = time.time() - - def on_batch_end(self, trainer, pl_module) -> None: - torch.cuda.synchronize(trainer.root_gpu) - max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 - pl_module.log("Peak Memory (GiB)", max_memory / 1000, prog_bar=True, on_step=True, sync_dist=True) - - def on_train_epoch_end(self, trainer, pl_module, outputs): - torch.cuda.synchronize(trainer.root_gpu) - max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 - epoch_time = time.time() - self.start_time - virt_mem = psutil.virtual_memory() - virt_mem = round((virt_mem.used / (1024 ** 3)), 2) - swap = psutil.swap_memory() - swap = round((swap.used / (1024 ** 3)), 2) - - max_memory = trainer.training_type_plugin.reduce(max_memory) - epoch_time = trainer.training_type_plugin.reduce(epoch_time) - virt_mem = trainer.training_type_plugin.reduce(virt_mem) - swap = trainer.training_type_plugin.reduce(swap) - - rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") - rank_zero_info(f"Average Peak CUDA memory {max_memory:.2f} MiB") - rank_zero_info(f"Average Peak Virtual memory {virt_mem:.2f} GiB") - rank_zero_info(f"Average Peak Swap memory {swap:.2f} Gib") - - -if __name__ == "__main__": - # 1. Create the DataModule - download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") - - datamodule = TextClassificationData.from_csv( - "review", - "sentiment", - train_file="data/imdb/train.csv", - val_file="data/imdb/valid.csv", - backbone="facebook/bart-large", - batch_size=4, - ) - - # 2. Build the task - model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=False) - - # 3. Create the trainer and finetune the model - trainer = flash.Trainer( - max_epochs=1, - plugins=DeepSpeedPlugin(stage=1), - callbacks=CUDACallback(), - precision=16, - accelerator="ddp", - gpus=1, - limit_train_batches=10, - limit_val_batches=10, - limit_test_batches=10, - ) - trainer.fit(model, datamodule=datamodule) - trainer.test(model) +# 1. Create the DataModule +download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") + +datamodule = TextClassificationData.from_csv( + "review", + "sentiment", + train_file="data/imdb/train.csv", + val_file="data/imdb/valid.csv", + backbone="prajjwal1/bert-medium", +) + +# 2. Build the task +model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 4. Classify a few sentences! How was the movie? +predictions = model.predict( + [ + "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", + "The worst movie in the history of cinema.", + "I come from Bulgaria where it 's almost impossible to have a tornado.", + ] +) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("text_classification_model.pt") \ No newline at end of file diff --git a/tests/text/classification/test_ort.py b/tests/text/classification/test_ort.py new file mode 100644 index 0000000000..b9ca9b9ca7 --- /dev/null +++ b/tests/text/classification/test_ort.py @@ -0,0 +1,65 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +import torch +from pytorch_lightning import Callback +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash import Trainer +from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE +from flash.text import TextClassifier +from flash.text.ort_callback import ORTCallback +from tests.helpers.boring_model import BoringModel +from tests.helpers.utils import _TEXT_TESTING +from tests.text.classification.test_model import DummyDataset, TEST_BACKBONE + +if _TORCH_ORT_AVAILABLE: + from torch_ort import ORTModule + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") +def test_init_train_enable_ort(tmpdir): + class TestCallback(Callback): + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + assert isinstance(pl_module.model, ORTModule) + + model = TextClassifier(2, TEST_BACKBONE, enable_ort=True) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback()) + trainer.fit( + model, + train_dataloader=torch.utils.data.DataLoader(DummyDataset()), + val_dataloaders=torch.utils.data.DataLoader(DummyDataset()), + ) + trainer.test( + model, + test_dataloaders=torch.utils.data.DataLoader(DummyDataset()) + ) + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") +def test_ort_callback_fails_no_model(tmpdir): + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback()) + with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"): + trainer.fit( + model, + train_dataloader=torch.utils.data.DataLoader(DummyDataset()), + val_dataloaders=torch.utils.data.DataLoader(DummyDataset()), + ) From 928df0cc10bcf6e9793979e0d9bff944f40ecb13 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 14:16:50 +0100 Subject: [PATCH 12/19] fix --- flash_examples/text_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index 6448eeb814..3d62dbb0dc 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -46,4 +46,4 @@ print(predictions) # 5. Save the model! -trainer.save_checkpoint("text_classification_model.pt") \ No newline at end of file +trainer.save_checkpoint("text_classification_model.pt") From cd336c0b1d4615b22b4100cba89a4c0602ed2a05 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Aug 2021 13:29:55 +0000 Subject: [PATCH 13/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/text/classification/test_ort.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/text/classification/test_ort.py b/tests/text/classification/test_ort.py index b9ca9b9ca7..01d987e092 100644 --- a/tests/text/classification/test_ort.py +++ b/tests/text/classification/test_ort.py @@ -46,10 +46,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: train_dataloader=torch.utils.data.DataLoader(DummyDataset()), val_dataloaders=torch.utils.data.DataLoader(DummyDataset()), ) - trainer.test( - model, - test_dataloaders=torch.utils.data.DataLoader(DummyDataset()) - ) + trainer.test(model, test_dataloaders=torch.utils.data.DataLoader(DummyDataset())) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") From 117f8c4253aab9a43444c9254a5006ce86caa419 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 14:41:59 +0100 Subject: [PATCH 14/19] Add docs for text classification and translation --- docs/source/reference/text_classification.rst | 14 ++++++++++++++ docs/source/reference/translation.rst | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst index 42424cc980..089548ded9 100644 --- a/docs/source/reference/text_classification.rst +++ b/docs/source/reference/text_classification.rst @@ -85,3 +85,17 @@ You can now perform inference from your client like this: .. literalinclude:: ../../../flash_examples/serve/text_classification/client.py :language: python :lines: 14- + +------ + +********************************************** +Accelerate Training & Inference with Torch ORT +********************************************** + +`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here `__. + +.. code-block:: python + + ... + + model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True) diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst index 939e3f544a..88e8098c23 100644 --- a/docs/source/reference/translation.rst +++ b/docs/source/reference/translation.rst @@ -85,3 +85,17 @@ You can now perform inference from your client like this: .. literalinclude:: ../../../flash_examples/serve/translation/client.py :language: python :lines: 14- + +------ + +********************************************** +Accelerate Training & Inference with Torch ORT +********************************************** + +`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here `__. + +.. code-block:: python + + ... + + model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True) From 12501e1aee78526a8d28f5ef3c431c18f23016d3 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 14:45:12 +0100 Subject: [PATCH 15/19] Add note --- docs/source/reference/text_classification.rst | 4 ++++ docs/source/reference/translation.rst | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst index 089548ded9..2728decf10 100644 --- a/docs/source/reference/text_classification.rst +++ b/docs/source/reference/text_classification.rst @@ -94,6 +94,10 @@ Accelerate Training & Inference with Torch ORT `Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here `__. +.. note:: + + Not all Transformer models are supported. See `this table `__ for support models + branches containing fixes for certain models. + .. code-block:: python ... diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst index 88e8098c23..2e8a5f70ff 100644 --- a/docs/source/reference/translation.rst +++ b/docs/source/reference/translation.rst @@ -94,6 +94,10 @@ Accelerate Training & Inference with Torch ORT `Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here `__. +.. note:: + + Not all Transformer models are supported. See `this table `__ for support models + branches containing fixes for certain models. + .. code-block:: python ... From 5c5d55a53e96eee999d56ed9cda2d3f896683e44 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 14:47:05 +0100 Subject: [PATCH 16/19] Add CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 22bd7058ba..b5c9ec4dd5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) +- Added Torch ORT support to Transformer based tasks ([#667](https://github.com/PyTorchLightning/lightning-flash/pull/667)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) From 0c1ecc44ec8166b76a889cfac380af37c82a51c1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 17 Aug 2021 16:12:36 +0100 Subject: [PATCH 17/19] Address code review --- docs/source/reference/summarization.rst | 18 ++++++++++++++++++ docs/source/reference/text_classification.rst | 4 ++-- docs/source/reference/translation.rst | 6 +++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst index ff7bedf4bc..ec3378d8f8 100644 --- a/docs/source/reference/summarization.rst +++ b/docs/source/reference/summarization.rst @@ -85,3 +85,21 @@ You can now perform inference from your client like this: .. literalinclude:: ../../../flash_examples/serve/summarization/client.py :language: python :lines: 14- + +------ + +********************************************** +Accelerate Training & Inference with Torch ORT +********************************************** + +`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``SummarizationTask`` once installed. See installation instructions `here `__. + +.. note:: + + Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models. + +.. code-block:: python + + ... + + model = SummarizationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True) \ No newline at end of file diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst index 2728decf10..989ce2e387 100644 --- a/docs/source/reference/text_classification.rst +++ b/docs/source/reference/text_classification.rst @@ -92,11 +92,11 @@ You can now perform inference from your client like this: Accelerate Training & Inference with Torch ORT ********************************************** -`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here `__. +`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here `__. .. note:: - Not all Transformer models are supported. See `this table `__ for support models + branches containing fixes for certain models. + Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models. .. code-block:: python diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst index 2e8a5f70ff..cc7c21c517 100644 --- a/docs/source/reference/translation.rst +++ b/docs/source/reference/translation.rst @@ -92,14 +92,14 @@ You can now perform inference from your client like this: Accelerate Training & Inference with Torch ORT ********************************************** -`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here `__. +`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TranslationTask`` once installed. See installation instructions `here `__. .. note:: - Not all Transformer models are supported. See `this table `__ for support models + branches containing fixes for certain models. + Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models. .. code-block:: python ... - model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True) + model = TranslationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True) From 7d035832f6cf580eef1f7f8fef7e9a36fdce8fdd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Aug 2021 15:14:13 +0000 Subject: [PATCH 18/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/reference/summarization.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst index ec3378d8f8..6010324cb1 100644 --- a/docs/source/reference/summarization.rst +++ b/docs/source/reference/summarization.rst @@ -102,4 +102,4 @@ Accelerate Training & Inference with Torch ORT ... - model = SummarizationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True) \ No newline at end of file + model = SummarizationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True) From f1723e7a6635e56d7e47f6e647e4233d577c480b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 17 Aug 2021 17:05:45 +0100 Subject: [PATCH 19/19] Apply suggestions from code review --- flash/text/seq2seq/question_answering/model.py | 2 +- flash/text/seq2seq/summarization/model.py | 2 +- flash/text/seq2seq/translation/model.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/text/seq2seq/question_answering/model.py b/flash/text/seq2seq/question_answering/model.py index c7aa2158af..0ebec8aed3 100644 --- a/flash/text/seq2seq/question_answering/model.py +++ b/flash/text/seq2seq/question_answering/model.py @@ -56,7 +56,7 @@ def __init__( num_beams: Optional[int] = 4, use_stemmer: bool = True, rouge_newline_sep: bool = True, - enable_ort: bool = True, + enable_ort: bool = False, ): self.save_hyperparameters() super().__init__( diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index 3ad9fea350..19e812baf1 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -56,7 +56,7 @@ def __init__( num_beams: Optional[int] = 4, use_stemmer: bool = True, rouge_newline_sep: bool = True, - enable_ort: bool = True, + enable_ort: bool = False, ): self.save_hyperparameters() super().__init__( diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index 5172b659ae..c70089e8d6 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -56,7 +56,7 @@ def __init__( num_beams: Optional[int] = 4, n_gram: bool = 4, smooth: bool = True, - enable_ort: bool = True, + enable_ort: bool = False, ): self.save_hyperparameters() super().__init__(