Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add support for schedulers #232

Merged
merged 6 commits into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Switch to use `torchmetrics` ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169))

- Better support for `optimizer` and `schedulers` ([#232](https://github.com/PyTorchLightning/lightning-flash/pull/232))



### Fixed

Expand All @@ -28,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `RetinaNet` & `backbones` to `ObjectDetector` Task ([#121](https://github.com/PyTorchLightning/lightning-flash/pull/121))
- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116),
- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116),
[#117](https://github.com/PyTorchLightning/lightning-flash/pull/117),
[#118](https://github.com/PyTorchLightning/lightning-flash/pull/118))

Expand Down
87 changes: 83 additions & 4 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess

Expand Down Expand Up @@ -64,11 +68,16 @@ class Task(LightningModule):
postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
"""

schedulers: FlashRegistry = _SCHEDULERS_REGISTRY

def __init__(
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
preprocess: Preprocess = None,
Expand All @@ -78,7 +87,11 @@ def __init__(
if model is not None:
self.model = model
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
self.optimizer_cls = optimizer
self.optimizer = optimizer
self.scheduler = scheduler
self.optimizer_kwargs = optimizer_kwargs or {}
self.scheduler_kwargs = scheduler_kwargs or {}

self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
Expand Down Expand Up @@ -168,8 +181,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
batch = torch.stack(batch)
return self(batch)

def configure_optimizers(self) -> torch.optim.Optimizer:
return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
optimizer = self.optimizer
if not isinstance(self.optimizer, Optimizer):
self.optimizer_kwargs["lr"] = self.learning_rate
optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs)
if self.scheduler:
return [optimizer], [self._instantiate_scheduler(optimizer)]
return optimizer

def configure_finetune_callback(self) -> List[Callback]:
return []
Expand Down Expand Up @@ -323,3 +342,63 @@ def available_models(cls) -> List[str]:
if registry is None:
return []
return registry.available_keys()

@classmethod
def available_schedulers(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None)
if registry is None:
return []
return registry.available_keys()

def get_num_training_steps(self) -> int:
"""Total training steps inferred from datamodule and devices."""
if not getattr(self, "trainer", None):
raise MisconfigurationException("The LightningModule isn't attached to the trainer yet.")
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
dataset_size = self.trainer.limit_train_batches
elif isinstance(self.trainer.limit_train_batches, float):
# limit_train_batches is a percentage of batches
dataset_size = len(self.train_dataloader())
dataset_size = int(dataset_size * self.trainer.limit_train_batches)
else:
dataset_size = len(self.train_dataloader())

num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe someone else can confirm, but I think num_gpus is automatically set to num_processes now, so this isn't needed, we should be able to do just num_devices = num_processes

if self.trainer.tpu_cores:
num_devices = max(num_devices, self.trainer.tpu_cores)

effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs

if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
return self.trainer.max_steps
return max_estimated_steps

def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> int:
if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0):
raise MisconfigurationException(
"`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
)
if isinstance(num_warmup_steps, float):
# Convert float values to percentage of training steps to use as warmup
num_warmup_steps *= num_training_steps
return round(num_warmup_steps)

def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
scheduler = self.scheduler
if isinstance(scheduler, _LRScheduler):
return scheduler
if isinstance(scheduler, str):
scheduler_fn = self.schedulers.get(self.scheduler)
num_training_steps: int = self.get_num_training_steps()
num_warmup_steps: int = self._compute_warmup(
num_training_steps=num_training_steps,
num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
)
return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
elif issubclass(scheduler, _LRScheduler):
return scheduler(optimizer, **self.scheduler_kwargs)
raise MisconfigurationException(
"scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` "
f"or a built-in scheduler in {self.available_schedulers()}"
)
14 changes: 14 additions & 0 deletions flash/core/schedulers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Callable, List

from flash.core.registry import FlashRegistry
from flash.utils.imports import _TRANSFORMERS_AVAILABLE

_SCHEDULERS_REGISTRY = FlashRegistry("scheduler")

if _TRANSFORMERS_AVAILABLE:
from transformers import optimization
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
functions: List[Callable] = [
getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')
]
for fn in functions:
_SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:])
3 changes: 3 additions & 0 deletions flash/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import weakref
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import imports
Expand Down Expand Up @@ -285,6 +286,8 @@ def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None
def _attach_preprocess_to_model(
self, model: 'Task', stage: Optional[RunningStage] = None, device_transform_only: bool = False
) -> None:
device_collate_fn = torch.nn.Identity()

if not stage:
stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]

Expand Down
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_TRANSFORMERS_AVAILABLE = _module_available("transformers")
65 changes: 65 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
import pytorch_lightning as pl
import torch
from PIL import Image
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader

import flash
from flash.core.classification import ClassificationTask
from flash.tabular import TabularClassifier
from flash.text import SummarizationTask, TextClassifier
from flash.utils.imports import _TRANSFORMERS_AVAILABLE
from flash.vision import ImageClassificationData, ImageClassifier

# ======== Mock functions ========
Expand Down Expand Up @@ -160,3 +164,64 @@ class Foo(ImageClassifier):
backbones = None

assert Foo.available_backbones() == []


def test_optimization(tmpdir):

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())
optim = torch.optim.Adam(model.parameters())
task = ClassificationTask(model, optimizer=optim, scheduler=None)

optimizer = task.configure_optimizers()
assert optimizer == optim

task = ClassificationTask(model, optimizer=torch.optim.Adadelta, optimizer_kwargs={"eps": 0.5}, scheduler=None)
optimizer = task.configure_optimizers()
assert isinstance(optimizer, torch.optim.Adadelta)
assert optimizer.defaults["eps"] == 0.5

task = ClassificationTask(
model,
optimizer=torch.optim.Adadelta,
scheduler=torch.optim.lr_scheduler.StepLR,
scheduler_kwargs={"step_size": 1}
)
optimizer, scheduler = task.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adadelta)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR)

optim = torch.optim.Adadelta(model.parameters())
task = ClassificationTask(model, optimizer=optim, scheduler=torch.optim.lr_scheduler.StepLR(optim, step_size=1))
optimizer, scheduler = task.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adadelta)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR)

if _TRANSFORMERS_AVAILABLE:
from transformers.optimization import get_linear_schedule_with_warmup

assert task.available_schedulers() == [
'constant_schedule', 'constant_schedule_with_warmup', 'cosine_schedule_with_warmup',
'cosine_with_hard_restarts_schedule_with_warmup', 'linear_schedule_with_warmup',
'polynomial_decay_schedule_with_warmup'
]

optim = torch.optim.Adadelta(model.parameters())
with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."):
task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup")
optimizer, scheduler = task.configure_optimizers()

task = ClassificationTask(
model,
optimizer=optim,
scheduler="linear_schedule_with_warmup",
scheduler_kwargs={"num_warmup_steps": 0.1},
loss_fn=F.nll_loss,
)
trainer = flash.Trainer(max_epochs=1, limit_train_batches=2)
ds = DummyDataset()
trainer.fit(task, train_dataloader=DataLoader(ds))
optimizer, scheduler = task.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adadelta)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR)
expected = get_linear_schedule_with_warmup.__name__
assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected