From 349c88c2f1cd06a1c3a641a970de0c14fc5f4c0c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 21 Apr 2021 17:30:38 +0100 Subject: [PATCH 1/2] Docs fixes --- README.md | 10 +++++----- docs/source/general/finetuning.rst | 2 +- docs/source/general/predictions.rst | 8 +++----- docs/source/general/training.rst | 6 +++--- docs/source/quickstart.rst | 6 +++--- docs/source/reference/image_classification.rst | 4 ++-- docs/source/reference/image_embedder.rst | 2 +- docs/source/reference/object_detection.rst | 2 +- docs/source/reference/summarization.rst | 4 ++-- docs/source/reference/tabular_classification.rst | 6 +++--- docs/source/reference/text_classification.rst | 4 ++-- docs/source/reference/translation.rst | 4 ++-- flash_examples/finetuning/summarization.py | 6 +++--- flash_examples/finetuning/translation.py | 2 +- flash_examples/predict/translation.py | 2 +- 15 files changed, 33 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 0609434469..baa31d2422 100644 --- a/README.md +++ b/README.md @@ -120,7 +120,7 @@ First, finetune: ```python # import our libraries import flash -from flash import download_data +from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. Download the data @@ -170,7 +170,7 @@ Flash has an Image embedding task to encodes an image into a vector of image fea View example ```python -from flash.core.data import download_data +from flash.data.utils import download_data from flash.vision import ImageEmbedder # 1. Download the data @@ -197,7 +197,7 @@ Flash has a Summarization task to sum up text from a larger article into a short ```python # import our libraries import flash -from flash import download_data +from flash.data.utils import download_data from flash.text import SummarizationData, SummarizationTask # 1. Download the data @@ -244,7 +244,7 @@ To illustrate, say we want to build a model to predict if a passenger survived o # import our libraries from torchmetrics.classification import Accuracy, Precision, Recall import flash -from flash import download_data +from flash.data.utils import download_data from flash.tabular import TabularClassifier, TabularData # 1. Download the data @@ -295,7 +295,7 @@ To illustrate, say we want to build a model on a tiny coco dataset. ```python # import our libraries import flash -from flash.core.data import download_data +from flash.data.utils import download_data from flash.vision import ObjectDetectionData, ObjectDetector # 1. Download the data diff --git a/docs/source/general/finetuning.rst b/docs/source/general/finetuning.rst index bdafca6360..a2b1cb95e1 100644 --- a/docs/source/general/finetuning.rst +++ b/docs/source/general/finetuning.rst @@ -52,7 +52,7 @@ Here are the steps in code .. code-block:: python import flash - from flash import download_data + from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. download and organize the data diff --git a/docs/source/general/predictions.rst b/docs/source/general/predictions.rst index a2214c74ba..76ca94511d 100644 --- a/docs/source/general/predictions.rst +++ b/docs/source/general/predictions.rst @@ -12,11 +12,11 @@ Predict on a single sample of data You can pass in a sample of data (image file path, a string of text, etc) to the :func:`~flash.core.model.Task.predict` method. - + .. 
code-block:: python from flash import Trainer - from flash.core.data import download_data + from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier @@ -37,7 +37,7 @@ Predict on a csv file .. code-block:: python - from flash.core.data import download_data + from flash.data.utils import download_data from flash.tabular import TabularClassifier # 1. Download the data @@ -51,5 +51,3 @@ Predict on a csv file # 3. Generate predictions from a csv file! Who would survive? predictions = model.predict("data/titanic/titanic.csv") print(predictions) - - diff --git a/docs/source/general/training.rst b/docs/source/general/training.rst index 67149bf742..aa1e3ec6ae 100644 --- a/docs/source/general/training.rst +++ b/docs/source/general/training.rst @@ -11,7 +11,7 @@ Some Flash tasks have been pretrained on large data sets. To accelerate your tra .. code-block:: python import flash - from flash import download_data + from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. download and organize the data @@ -48,7 +48,7 @@ Flash tasks supports many advanced training functionalities out-of-the-box, such # train on 1 GPU flash.Trainer(gpus=1) - + * Training on multiple GPUs .. code-block:: python @@ -60,7 +60,7 @@ Flash tasks supports many advanced training functionalities out-of-the-box, such # train on gpu 1, 3, 5 (3 gpus total) flash.Trainer(gpus=[1, 3, 5]) - + * Using mixed precision training .. code-block:: python diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index f30f4ec7eb..1614315f50 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -15,7 +15,7 @@ For getting started with Deep Learning Easy to learn ^^^^^^^^^^^^^ -If you are just getting started with deep learning, Flash offers common deep learning tasks you can use out-of-the-box in a few lines of code, no math, fancy nn.Modules or research experience required! +If you are just getting started with deep learning, Flash offers common deep learning tasks you can use out-of-the-box in a few lines of code, no math, fancy nn.Modules or research experience required! Easy to scale ^^^^^^^^^^^^^ @@ -70,7 +70,7 @@ You can install flash using pip or conda: Tasks ===== -Flash is comprised of a collection of Tasks. The Flash tasks are laser-focused objects designed to solve a well-defined type of problem, using state-of-the-art methods. +Flash is comprised of a collection of Tasks. The Flash tasks are laser-focused objects designed to solve a well-defined type of problem, using state-of-the-art methods. The Flash tasks contain all the relevant information to solve the task at hand- the number of class labels you want to predict, number of columns in your dataset, as well as details on the model architecture used such as loss function, optimizers, etc. @@ -137,7 +137,7 @@ Here's an example of finetuning. .. code-block:: python import flash - from flash import download_data + from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. 
Download the data diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index c457be23c6..ce5cee4a48 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -25,7 +25,7 @@ Use the :class:`~flash.vision.ImageClassifier` pretrained model for inference on # import our libraries from flash import Trainer - from flash import download_data + from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. Download the data @@ -90,7 +90,7 @@ Now all we need is three lines of code to build to train our task! .. code-block:: python import flash - from flash import download_data + from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. Download the data diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst index f2c2b2b36f..3696af4da6 100644 --- a/docs/source/reference/image_embedder.rst +++ b/docs/source/reference/image_embedder.rst @@ -54,7 +54,7 @@ To tailor this image embedder to your dataset, finetune first. .. code-block:: python import flash - from flash import download_data + from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageEmbedder # 1. Download the data diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index bed0c9fd53..d6cd1e1c91 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -75,7 +75,7 @@ To tailor the object detector to your dataset, you would need to have it in `COC .. code-block:: python import flash - from flash.core.data import download_data + from flash.data.utils import download_data from flash.vision import ObjectDetectionData, ObjectDetector # 1. Download the data diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst index 5f47542b2e..bfa389ba9a 100644 --- a/docs/source/reference/summarization.rst +++ b/docs/source/reference/summarization.rst @@ -60,7 +60,7 @@ Or on a given dataset, use :class:`~flash.core.trainer.Trainer` `predict` method # import our libraries from flash import Trainer - from flash import download_data + from flash.data.utils import download_data from flash.text import SummarizationData, SummarizationTask # 1. Download data @@ -104,7 +104,7 @@ All we need is three lines of code to train our model! # import our libraries import flash - from flash import download_data + from flash.data.utils import download_data from flash.text import SummarizationData, SummarizationTask # 1. Download data diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index 8952ffa1eb..e54356c751 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -45,7 +45,7 @@ Next, we create the :class:`~flash.tabular.TabularClassifier` task, using the Da .. code-block:: python import flash - from flash import download_data + from flash.data.utils import download_data from flash.tabular import TabularClassifier, TabularData from torchmetrics.classification import Accuracy, Precision, Recall @@ -92,7 +92,7 @@ You can make predcitions on a pretrained model, that has already been trained fo .. 
code-block:: python - from flash.core.data import download_data + from flash.data.utils import download_data from flash.tabular import TabularClassifier # 1. Download the data @@ -113,7 +113,7 @@ Or you can finetune your own model and use that for prediction: .. code-block:: python import flash - from flash import download_data + from flash.data.utils import download_data from flash.tabular import TabularClassifier, TabularData # 1. Load the data diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst index c68fa925b7..e821af7e0d 100644 --- a/docs/source/reference/text_classification.rst +++ b/docs/source/reference/text_classification.rst @@ -24,7 +24,7 @@ Use the :class:`~flash.text.classification.model.TextClassifier` pretrained mode from pytorch_lightning import Trainer - from flash import download_data + from flash.data.utils import download_data from flash.text import TextClassificationData, TextClassifier @@ -77,7 +77,7 @@ All we need is three lines of code to train our model! .. code-block:: python import flash - from flash.core.data import download_data + from flash.data.utils import download_data from flash.text import TextClassificationData, TextClassifier # 1. Download the data diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst index cdcd07db01..05ee877deb 100644 --- a/docs/source/reference/translation.rst +++ b/docs/source/reference/translation.rst @@ -42,7 +42,7 @@ Or on a given dataset, use :class:`~flash.core.trainer.Trainer` `predict` method # import our libraries from flash import Trainer - from flash import download_data + from flash.data.utils import download_data from flash.text import TranslationData, TranslationTask # 1. Download data @@ -86,7 +86,7 @@ All we need is three lines of code to train our model! By default, we use a `mBA # import our libraries import flash - from flash import download_data + from flash.data.utils import download_data from flash.text import TranslationData, TranslationTask # 1. Download data diff --git a/flash_examples/finetuning/summarization.py b/flash_examples/finetuning/summarization.py index 08e3f63f4b..d2ecc726f3 100644 --- a/flash_examples/finetuning/summarization.py +++ b/flash_examples/finetuning/summarization.py @@ -13,8 +13,8 @@ # limitations under the License. import torch -import flash -from flash import download_data, Trainer +from flash import Trainer +from flash.data.utils import download_data from flash.text import SummarizationData, SummarizationTask # 1. Download the data @@ -33,7 +33,7 @@ model = SummarizationTask() # 4. Create the trainer. Run once on data -trainer = flash.Trainer(gpus=int(torch.cuda.is_available()), fast_dev_run=True) +trainer = Trainer(gpus=int(torch.cuda.is_available()), fast_dev_run=True) # 5. Fine-tune the model trainer.finetune(model, datamodule=datamodule) diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py index fe3e0a3f24..be91ea057d 100644 --- a/flash_examples/finetuning/translation.py +++ b/flash_examples/finetuning/translation.py @@ -14,7 +14,7 @@ import torch import flash -from flash import download_data +from flash.data.utils import download_data from flash.text import TranslationData, TranslationTask # 1. 
Download the data diff --git a/flash_examples/predict/translation.py b/flash_examples/predict/translation.py index a956a4af5a..bbf3d42446 100644 --- a/flash_examples/predict/translation.py +++ b/flash_examples/predict/translation.py @@ -13,7 +13,7 @@ # limitations under the License. from pytorch_lightning import Trainer -from flash import download_data +from flash.data.utils import download_data from flash.text import TranslationData, TranslationTask # 1. Download the data From 79f271ed8c37e079fd0284cbf6d67d747147b091 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 21 Apr 2021 20:40:24 +0100 Subject: [PATCH 2/2] [feat] Add support for schedulers (#232) * add support for schedulers * update changelog * resolve typing * update task * change for log softmax * udpate on comments --- CHANGELOG.md | 5 ++- flash/core/model.py | 87 +++++++++++++++++++++++++++++++++++-- flash/core/schedulers.py | 14 ++++++ flash/data/data_pipeline.py | 3 ++ flash/utils/imports.py | 1 + tests/core/test_model.py | 65 +++++++++++++++++++++++++++ 6 files changed, 170 insertions(+), 5 deletions(-) create mode 100644 flash/core/schedulers.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c50aaab6da..627e907654 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Switch to use `torchmetrics` ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169)) +- Better support for `optimizer` and `schedulers` ([#232](https://github.com/PyTorchLightning/lightning-flash/pull/232)) + + ### Fixed @@ -28,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added `RetinaNet` & `backbones` to `ObjectDetector` Task ([#121](https://github.com/PyTorchLightning/lightning-flash/pull/121)) -- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116), +- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116), [#117](https://github.com/PyTorchLightning/lightning-flash/pull/117), [#118](https://github.com/PyTorchLightning/lightning-flash/pull/118)) diff --git a/flash/core/model.py b/flash/core/model.py index 9914b4cb61..02dc367932 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -21,9 +21,13 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer from flash.core.registry import FlashRegistry +from flash.core.schedulers import _SCHEDULERS_REGISTRY from flash.core.utils import get_callable_dict from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess @@ -64,11 +68,16 @@ class Task(LightningModule): postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task. 
""" + schedulers: FlashRegistry = _SCHEDULERS_REGISTRY + def __init__( self, model: Optional[nn.Module] = None, loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 5e-5, preprocess: Preprocess = None, @@ -78,7 +87,11 @@ def __init__( if model is not None: self.model = model self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn) - self.optimizer_cls = optimizer + self.optimizer = optimizer + self.scheduler = scheduler + self.optimizer_kwargs = optimizer_kwargs or {} + self.scheduler_kwargs = scheduler_kwargs or {} + self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.learning_rate = learning_rate # TODO: should we save more? Bug on some regarding yaml if we save metrics @@ -168,8 +181,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A batch = torch.stack(batch) return self(batch) - def configure_optimizers(self) -> torch.optim.Optimizer: - return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate) + def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: + optimizer = self.optimizer + if not isinstance(self.optimizer, Optimizer): + self.optimizer_kwargs["lr"] = self.learning_rate + optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs) + if self.scheduler: + return [optimizer], [self._instantiate_scheduler(optimizer)] + return optimizer def configure_finetune_callback(self) -> List[Callback]: return [] @@ -323,3 +342,63 @@ def available_models(cls) -> List[str]: if registry is None: return [] return registry.available_keys() + + @classmethod + def available_schedulers(cls) -> List[str]: + registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None) + if registry is None: + return [] + return registry.available_keys() + + def get_num_training_steps(self) -> int: + """Total training steps inferred from datamodule and devices.""" + if not getattr(self, "trainer", None): + raise MisconfigurationException("The LightningModule isn't attached to the trainer yet.") + if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0: + dataset_size = self.trainer.limit_train_batches + elif isinstance(self.trainer.limit_train_batches, float): + # limit_train_batches is a percentage of batches + dataset_size = len(self.train_dataloader()) + dataset_size = int(dataset_size * self.trainer.limit_train_batches) + else: + dataset_size = len(self.train_dataloader()) + + num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) + if self.trainer.tpu_cores: + num_devices = max(num_devices, self.trainer.tpu_cores) + + effective_batch_size = self.trainer.accumulate_grad_batches * num_devices + max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs + + if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps: + return self.trainer.max_steps + return max_estimated_steps + + def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> int: + if 
not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0): + raise MisconfigurationException( + "`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`" + ) + if isinstance(num_warmup_steps, float): + # Convert float values to percentage of training steps to use as warmup + num_warmup_steps *= num_training_steps + return round(num_warmup_steps) + + def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: + scheduler = self.scheduler + if isinstance(scheduler, _LRScheduler): + return scheduler + if isinstance(scheduler, str): + scheduler_fn = self.schedulers.get(self.scheduler) + num_training_steps: int = self.get_num_training_steps() + num_warmup_steps: int = self._compute_warmup( + num_training_steps=num_training_steps, + num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), + ) + return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) + elif issubclass(scheduler, _LRScheduler): + return scheduler(optimizer, **self.scheduler_kwargs) + raise MisconfigurationException( + "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` " + f"or a built-in scheduler in {self.available_schedulers()}" + ) diff --git a/flash/core/schedulers.py b/flash/core/schedulers.py new file mode 100644 index 0000000000..eee60cc8f8 --- /dev/null +++ b/flash/core/schedulers.py @@ -0,0 +1,14 @@ +from typing import Callable, List + +from flash.core.registry import FlashRegistry +from flash.utils.imports import _TRANSFORMERS_AVAILABLE + +_SCHEDULERS_REGISTRY = FlashRegistry("scheduler") + +if _TRANSFORMERS_AVAILABLE: + from transformers import optimization + functions: List[Callable] = [ + getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler') + ] + for fn in functions: + _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:]) diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 46be3c823c..fe75404f1c 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -16,6 +16,7 @@ import weakref from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union +import torch from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import imports @@ -285,6 +286,8 @@ def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None def _attach_preprocess_to_model( self, model: 'Task', stage: Optional[RunningStage] = None, device_transform_only: bool = False ) -> None: + device_collate_fn = torch.nn.Identity() + if not stage: stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] diff --git a/flash/utils/imports.py b/flash/utils/imports.py index 5e17ba6d3e..5252a3e3d5 100644 --- a/flash/utils/imports.py +++ b/flash/utils/imports.py @@ -5,3 +5,4 @@ _COCO_AVAILABLE = _module_available("pycocotools") _TIMM_AVAILABLE = _module_available("timm") _TORCHVISION_AVAILABLE = _module_available("torchvision") +_TRANSFORMERS_AVAILABLE = _module_available("transformers") diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d0b0048b23..450b662dbd 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -20,12 +20,16 @@ import pytorch_lightning as pl import torch from PIL import Image +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor from torch.nn import 
functional as F +from torch.utils.data import DataLoader +import flash from flash.core.classification import ClassificationTask from flash.tabular import TabularClassifier from flash.text import SummarizationTask, TextClassifier +from flash.utils.imports import _TRANSFORMERS_AVAILABLE from flash.vision import ImageClassificationData, ImageClassifier # ======== Mock functions ======== @@ -160,3 +164,64 @@ class Foo(ImageClassifier): backbones = None assert Foo.available_backbones() == [] + + +def test_optimization(tmpdir): + + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) + optim = torch.optim.Adam(model.parameters()) + task = ClassificationTask(model, optimizer=optim, scheduler=None) + + optimizer = task.configure_optimizers() + assert optimizer == optim + + task = ClassificationTask(model, optimizer=torch.optim.Adadelta, optimizer_kwargs={"eps": 0.5}, scheduler=None) + optimizer = task.configure_optimizers() + assert isinstance(optimizer, torch.optim.Adadelta) + assert optimizer.defaults["eps"] == 0.5 + + task = ClassificationTask( + model, + optimizer=torch.optim.Adadelta, + scheduler=torch.optim.lr_scheduler.StepLR, + scheduler_kwargs={"step_size": 1} + ) + optimizer, scheduler = task.configure_optimizers() + assert isinstance(optimizer[0], torch.optim.Adadelta) + assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + optim = torch.optim.Adadelta(model.parameters()) + task = ClassificationTask(model, optimizer=optim, scheduler=torch.optim.lr_scheduler.StepLR(optim, step_size=1)) + optimizer, scheduler = task.configure_optimizers() + assert isinstance(optimizer[0], torch.optim.Adadelta) + assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + if _TRANSFORMERS_AVAILABLE: + from transformers.optimization import get_linear_schedule_with_warmup + + assert task.available_schedulers() == [ + 'constant_schedule', 'constant_schedule_with_warmup', 'cosine_schedule_with_warmup', + 'cosine_with_hard_restarts_schedule_with_warmup', 'linear_schedule_with_warmup', + 'polynomial_decay_schedule_with_warmup' + ] + + optim = torch.optim.Adadelta(model.parameters()) + with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."): + task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup") + optimizer, scheduler = task.configure_optimizers() + + task = ClassificationTask( + model, + optimizer=optim, + scheduler="linear_schedule_with_warmup", + scheduler_kwargs={"num_warmup_steps": 0.1}, + loss_fn=F.nll_loss, + ) + trainer = flash.Trainer(max_epochs=1, limit_train_batches=2) + ds = DummyDataset() + trainer.fit(task, train_dataloader=DataLoader(ds)) + optimizer, scheduler = task.configure_optimizers() + assert isinstance(optimizer[0], torch.optim.Adadelta) + assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) + expected = get_linear_schedule_with_warmup.__name__ + assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected
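
---

Taken together, the scheduler support introduced in PATCH 2/2 can be exercised end to end. Below is a minimal usage sketch, not part of the patch itself: it assumes the patch is applied and that `transformers` is installed so the string scheduler names resolve through the new `_SCHEDULERS_REGISTRY`, and it invents a tiny `RandomDataset` purely for illustration. It is modeled directly on the `test_optimization` test added above.

```python
# Minimal sketch of the new optimizer/scheduler API (assumes this patch is
# applied and `transformers` is installed; `RandomDataset` is a made-up
# stand-in dataset used only for illustration).
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

import flash
from flash.core.classification import ClassificationTask


class RandomDataset(Dataset):
    """Hypothetical toy dataset: random 28x28 inputs with 10 class labels."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int):
        return torch.rand(1, 28, 28), torch.randint(10, size=(1, )).item()


model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=-1))

# `scheduler` may now be a registry string: it is looked up in `Task.schedulers`
# (the transformers-backed `_SCHEDULERS_REGISTRY`), and a float
# `num_warmup_steps` is interpreted by `_compute_warmup` as a fraction of the
# total training steps inferred by `get_num_training_steps`.
task = ClassificationTask(
    model,
    loss_fn=F.nll_loss,
    optimizer=torch.optim.Adadelta,
    optimizer_kwargs={"eps": 0.5},
    scheduler="linear_schedule_with_warmup",
    scheduler_kwargs={"num_warmup_steps": 0.1},
)

# Lists the registered names, e.g. 'constant_schedule',
# 'linear_schedule_with_warmup', ... (empty if transformers is missing).
print(task.available_schedulers())

trainer = flash.Trainer(max_epochs=1, limit_train_batches=2)
trainer.fit(task, train_dataloader=DataLoader(RandomDataset()))
```

Note the design choice this exercises: in `configure_optimizers` and `_instantiate_scheduler`, an already-instantiated `Optimizer` or `_LRScheduler` is passed through unchanged, while a class or registry string is instantiated lazily with `optimizer_kwargs` / `scheduler_kwargs`. That lazy path is what allows a fractional `num_warmup_steps` to be resolved against the number of training steps, which is only known once the task is attached to a trainer.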