This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add support for Torch ORT to Transformer based Tasks #667

Merged: 22 commits, Aug 17, 2021
Changes from 1 commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Aug 17, 2021
commit 97ae954351875a53c2d5b957a5829a2c0142eb96
30 changes: 15 additions & 15 deletions flash_examples/text_classification.py
@@ -14,27 +14,27 @@
 import time
 from typing import Any

-import flash
 import psutil
 import torch
-from flash.core.data.utils import download_data
-from flash.text import TextClassificationData, TextClassifier
 from pytorch_lightning import Callback
 from pytorch_lightning.plugins import DeepSpeedPlugin
 from pytorch_lightning.utilities import rank_zero_info
 from pytorch_lightning.utilities.types import STEP_OUTPUT

+import flash
+from flash.core.data.utils import download_data
+from flash.text import TextClassificationData, TextClassifier
+

 class CUDACallback(Callback):
-
     def on_train_batch_end(
-            self,
-            trainer: "pl.Trainer",
-            pl_module: "pl.LightningModule",
-            outputs: STEP_OUTPUT,
-            batch: Any,
-            batch_idx: int,
-            dataloader_idx: int,
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int,
     ) -> None:
         if batch_idx == 1:
             # only start at the second batch
@@ -46,7 +46,7 @@ def on_train_batch_end(
     def on_batch_end(self, trainer, pl_module) -> None:
         torch.cuda.synchronize(trainer.root_gpu)
         max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
-        pl_module.log('Peak Memory (GiB)', max_memory / 1000, prog_bar=True, on_step=True, sync_dist=True)
+        pl_module.log("Peak Memory (GiB)", max_memory / 1000, prog_bar=True, on_step=True, sync_dist=True)

     def on_train_epoch_end(self, trainer, pl_module, outputs):
         torch.cuda.synchronize(trainer.root_gpu)
@@ -68,7 +68,7 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
         rank_zero_info(f"Average Peak Swap memory {swap:.2f} Gib")


-if __name__ == '__main__':
+if __name__ == "__main__":
     # 1. Create the DataModule
     download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

@@ -90,9 +90,9 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
         plugins=DeepSpeedPlugin(stage=1),
         callbacks=CUDACallback(),
         precision=16,
-        accelerator='ddp',
+        accelerator="ddp",
         gpus=4,
         limit_val_batches=0,
-        limit_test_batches=0
+        limit_test_batches=0,
     )
     trainer.fit(model, datamodule=datamodule)
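
The commit above is only the formatter pass (isort import regrouping plus black quote and trailing-comma normalization); the feature the PR itself targets is ONNX Runtime (Torch ORT) training for Transformer-based tasks. The following is a minimal usage sketch, not code from this PR: it assumes the merged API exposes an `enable_ort` flag on `TextClassifier` (flag name inferred from the PR title), reuses the IMDB data from the example file, and requires the `torch-ort` package to be installed.

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# Reuse the IMDB dataset from the example above.
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

datamodule = TextClassificationData.from_csv(
    "review",        # input column
    "sentiment",     # target column
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
)

# `enable_ort=True` is the switch this PR is meant to add (name assumed from the
# PR title); it wraps the Hugging Face backbone for ONNX Runtime training.
model = TextClassifier(
    backbone="prajjwal1/bert-medium",
    num_classes=datamodule.num_classes,
    enable_ort=True,
)

trainer = flash.Trainer(max_epochs=1, gpus=1, precision=16)
trainer.fit(model, datamodule=datamodule)

The exact behaviour when `torch-ort` is missing (error versus silent fallback) is not visible in this commit, so treat the sketch as illustrative only.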