From 2f46b93520ce06811026915bd0ba793895385051 Mon Sep 17 00:00:00 2001
From: Karthik Rangasai
Date: Wed, 15 Sep 2021 11:48:27 +0530
Subject: [PATCH 01/22] Change optimizer to accept only Callables, and
 scheduler to support Callables and strings.

---
 flash/core/model.py      | 41 +++++++++++++++++++++-------------------
 tests/core/test_model.py | 24 ++++++++++++-----------
 2 files changed, 35 insertions(+), 30 deletions(-)

diff --git a/flash/core/model.py b/flash/core/model.py
index 8f173ce590..d20671d218 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -17,7 +17,7 @@
 from abc import ABCMeta
 from copy import deepcopy
 from importlib import import_module
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 import pytorch_lightning as pl
 import torch
@@ -29,6 +29,7 @@
 from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import nn
+from torch.functional import Tensor
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Sampler
@@ -302,9 +303,9 @@ def __init__(
         self,
         model: Optional[nn.Module] = None,
         loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
-        optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
-        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        optimizer: Optional[Callable[[Iterable[Tensor]], Optimizer]] = functools.partial(torch.optim.Adam),
+        # optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[str, Callable[..., _LRScheduler]]] = None,
         scheduler_kwargs: Optional[Dict[str, Any]] = None,
         metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
         learning_rate: float = 5e-5,
@@ -319,8 +320,8 @@ def __init__(
         self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
         self.optimizer = optimizer
         self.scheduler = scheduler
-        self.optimizer_kwargs = optimizer_kwargs or {}
-        self.scheduler_kwargs = scheduler_kwargs or {}
+        # self.optimizer_kwargs: Dict[str, Any] = optimizer_kwargs or {}
+        self.scheduler_kwargs: Dict[str, Any] = scheduler_kwargs or {}
 
         self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
         self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
@@ -464,11 +465,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
         return self(batch)
 
     def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
-        optimizer = self.optimizer
-        if not isinstance(self.optimizer, Optimizer):
-            self.optimizer_kwargs["lr"] = self.learning_rate
-            optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs)
-        if self.scheduler:
+        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
+        optimizer: Optimizer = self.optimizer(model_parameters, lr=self.learning_rate)
+        # if not isinstance(self.optimizer, Optimizer):
+        #     self.optimizer_kwargs["lr"] = self.learning_rate
+        #     optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs)
+        if self.scheduler is not None:
             return [optimizer], [self._instantiate_scheduler(optimizer)]
         return optimizer
 
@@ -808,8 +810,8 @@ def 
_compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float] def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: scheduler = self.scheduler - if isinstance(scheduler, _LRScheduler): - return scheduler + # if isinstance(scheduler, Callable): + # return scheduler(optimizer) if isinstance(scheduler, str): scheduler_fn = self.schedulers.get(self.scheduler) num_training_steps: int = self.get_num_training_steps() @@ -818,12 +820,13 @@ def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), ) return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) - if issubclass(scheduler, _LRScheduler): - return scheduler(optimizer, **self.scheduler_kwargs) - raise MisconfigurationException( - "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` " - f"or a built-in scheduler in {self.available_schedulers()}" - ) + # # if issubclass(scheduler, _LRScheduler): + # # return scheduler(optimizer, **self.scheduler_kwargs) + # raise MisconfigurationException( + # "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` " + # f"or a built-in scheduler in {self.available_schedulers()}" + # ) + return scheduler(optimizer) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 3d3b53b111..945396bbac 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import math from itertools import chain from numbers import Number @@ -287,29 +288,30 @@ class Foo(ImageClassifier): def test_optimization(tmpdir): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) - optim = torch.optim.Adam(model.parameters()) - task = ClassificationTask(model, optimizer=optim, scheduler=None) + # optim = functools.partial(torch.optim.Adam) # (model.parameters()) + # task = ClassificationTask(model, optimizer=optim, scheduler=None) - optimizer = task.configure_optimizers() - assert optimizer == optim + # optimizer = task.configure_optimizers() + # assert optimizer == optim - task = ClassificationTask(model, optimizer=torch.optim.Adadelta, optimizer_kwargs={"eps": 0.5}, scheduler=None) + task = ClassificationTask(model, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), scheduler=None) optimizer = task.configure_optimizers() assert isinstance(optimizer, torch.optim.Adadelta) assert optimizer.defaults["eps"] == 0.5 task = ClassificationTask( model, - optimizer=torch.optim.Adadelta, - scheduler=torch.optim.lr_scheduler.StepLR, - scheduler_kwargs={"step_size": 1}, + optimizer=functools.partial(torch.optim.Adadelta), + scheduler=functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1), + # scheduler_kwargs={"step_size": 1}, ) optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) - optim = torch.optim.Adadelta(model.parameters()) - task = ClassificationTask(model, optimizer=optim, scheduler=torch.optim.lr_scheduler.StepLR(optim, step_size=1)) + optim = functools.partial(torch.optim.Adadelta) # (model.parameters()) + scheduler = functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1) 
+ task = ClassificationTask(model, optimizer=optim, scheduler=scheduler) optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) @@ -319,7 +321,7 @@ def test_optimization(tmpdir): assert isinstance(task.available_schedulers(), list) - optim = torch.optim.Adadelta(model.parameters()) + optim = functools.partial(torch.optim.Adadelta) # (model.parameters()) with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."): task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup") optimizer, scheduler = task.configure_optimizers() From caefe689c2d0984de54d662d16b4b7e6a50be2d9 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 15 Sep 2021 20:26:23 +0530 Subject: [PATCH 02/22] Add Optimizer Registry and Update __init__ for all tasks. --- flash/audio/speech_recognition/model.py | 8 ++++---- flash/core/model.py | 13 ++++++++----- flash/core/optimizers/__init__.py | 2 ++ flash/core/optimizers/optimizers.py | 12 ++++++++++++ flash/core/{ => optimizers}/schedulers.py | 0 flash/graph/classification/model.py | 6 +++--- flash/image/classification/model.py | 6 +++--- flash/image/detection/model.py | 9 ++++----- flash/image/embedding/model.py | 6 +++--- flash/image/instance_segmentation/model.py | 9 ++++----- flash/image/keypoint_detection/model.py | 9 ++++----- flash/image/segmentation/model.py | 6 +++--- flash/image/style_transfer/model.py | 8 ++++---- flash/pointcloud/detection/model.py | 6 +++--- flash/pointcloud/segmentation/model.py | 6 +++--- flash/tabular/classification/model.py | 6 +++--- flash/template/classification/model.py | 6 +++--- flash/text/classification/model.py | 6 +++--- flash/text/question_answering/model.py | 6 +++--- flash/text/seq2seq/core/model.py | 6 +++--- flash/text/seq2seq/summarization/model.py | 6 +++--- flash/text/seq2seq/translation/model.py | 6 +++--- flash/video/classification/model.py | 6 +++--- tests/core/test_model.py | 15 +++++++-------- 24 files changed, 91 insertions(+), 78 deletions(-) create mode 100644 flash/core/optimizers/optimizers.py rename flash/core/{ => optimizers}/schedulers.py (100%) diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index e8e6137022..9486636c3d 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os import warnings -from typing import Any, Dict, Mapping, Optional, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Type, Union import torch import torch.nn as nn @@ -54,8 +54,8 @@ class SpeechRecognition(Task): def __init__( self, backbone: str = "facebook/wav2vec2-base-960h", - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, learning_rate: float = 1e-5, @@ -71,7 +71,7 @@ def __init__( super().__init__( model=model, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, learning_rate=learning_rate, diff --git a/flash/core/model.py b/flash/core/model.py index d20671d218..b51b3cfe1e 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -17,7 +17,7 @@ from abc import ABCMeta from copy import deepcopy from importlib import import_module -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union import pytorch_lightning as pl import torch @@ -29,7 +29,6 @@ from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn -from torch.functional import Tensor from torch.optim.lr_scheduler import _LRScheduler from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader, Sampler @@ -47,8 +46,8 @@ SerializerMapping, ) from flash.core.data.properties import ProcessState +from flash.core.optimizers import _OPTIMIZERS_REGISTRY, _SCHEDULERS_REGISTRY from flash.core.registry import FlashRegistry -from flash.core.schedulers import _SCHEDULERS_REGISTRY from flash.core.serve import Composition from flash.core.utilities.apply_func import get_callable_dict from flash.core.utilities.imports import requires @@ -295,6 +294,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check postprocess: :class:`~flash.core.data.process.Postprocess` to use as the default for this task. 
""" + optimizers: FlashRegistry = _OPTIMIZERS_REGISTRY schedulers: FlashRegistry = _SCHEDULERS_REGISTRY required_extras: Optional[Union[str, List[str]]] = None @@ -303,12 +303,12 @@ def __init__( self, model: Optional[nn.Module] = None, loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Optional[Callable[[Iterable[Tensor]], Optimizer]] = functools.partial(torch.optim.Adam), + learning_rate: float = 5e-5, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[str, Callable[..., _LRScheduler]]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - learning_rate: float = 5e-5, deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None, preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None, @@ -465,6 +465,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return self(batch) def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: + if isinstance(self.optimizer, str): + self.optimizer = self.optimizers.get(self.optimizer) + model_parameters = filter(lambda p: p.requires_grad, self.parameters()) optimizer: Optimizer = self.optimizer(model_parameters, lr=self.learning_rate) # if not isinstance(self.optimizer, Optimizer): diff --git a/flash/core/optimizers/__init__.py b/flash/core/optimizers/__init__.py index 76b1ef8a3e..0cf8cd1966 100644 --- a/flash/core/optimizers/__init__.py +++ b/flash/core/optimizers/__init__.py @@ -1,3 +1,5 @@ from flash.core.optimizers.lamb import LAMB # noqa: F401 from flash.core.optimizers.lars import LARS # noqa: F401 from flash.core.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR # noqa: F401 +from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY # noqa: F401 +from flash.core.optimizers.schedulers import _SCHEDULERS_REGISTRY # noqa: F401 diff --git a/flash/core/optimizers/optimizers.py b/flash/core/optimizers/optimizers.py new file mode 100644 index 0000000000..8f7d22b935 --- /dev/null +++ b/flash/core/optimizers/optimizers.py @@ -0,0 +1,12 @@ +from typing import Callable, List + +from torch import optim + +from flash.core.registry import FlashRegistry + +_OPTIMIZERS_REGISTRY = FlashRegistry("optimizer") + +_optimizers: List[Callable] = [getattr(optim, n) for n in dir(optim) if ("_" not in n)] + +for fn in _optimizers: + _OPTIMIZERS_REGISTRY(fn, name=fn.__name__) diff --git a/flash/core/schedulers.py b/flash/core/optimizers/schedulers.py similarity index 100% rename from flash/core/schedulers.py rename to flash/core/optimizers/schedulers.py diff --git a/flash/graph/classification/model.py b/flash/graph/classification/model.py index 14fe6b2696..26eb8bc90d 100644 --- a/flash/graph/classification/model.py +++ b/flash/graph/classification/model.py @@ -109,8 +109,8 @@ def __init__( num_classes: int, hidden_channels: Union[List[int], int] = 512, loss_fn: Callable = F.cross_entropy, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Callable, Mapping, Sequence, None] = None, @@ -132,7 +132,7 @@ def __init__( 
model=model, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index 89071ad71c..0f22a37e1b 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -79,8 +79,8 @@ def __init__( head: Optional[Union[FunctionType, nn.Module]] = None, pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, - optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -93,7 +93,7 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 9b52b314db..5922ef586b 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Mapping, Optional, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union import torch -from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from flash.core.adapter import AdapterTask @@ -61,8 +60,8 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "retinanet", pretrained: bool = True, - optimizer: Type[Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, learning_rate: float = 5e-3, @@ -85,7 +84,7 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, serializer=serializer or Preds(), diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index c803757ec5..3c255b3f10 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -64,8 +64,8 @@ def __init__( backbone: str = "resnet101", pretrained: bool = True, loss_fn: Callable = F.cross_entropy, - optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "SGD", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = (Accuracy()), @@ -76,7 +76,7 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - 
optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 1d2b8ebf32..354611587f 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Mapping, Optional, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union import torch -from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from flash.core.adapter import AdapterTask @@ -61,8 +60,8 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "mask_rcnn", pretrained: bool = True, - optimizer: Type[Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, learning_rate: float = 5e-4, @@ -85,7 +84,7 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, serializer=serializer or Preds(), diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 36a7f361af..c6b054f367 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Mapping, Optional, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union import torch -from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from flash.core.adapter import AdapterTask @@ -62,8 +61,8 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "keypoint_rcnn", pretrained: bool = True, - optimizer: Type[Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, learning_rate: float = 5e-4, @@ -87,7 +86,7 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, serializer=serializer or Preds(), diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 83a654ec31..7460ab15d0 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -83,8 +83,8 @@ def __init__( head_kwargs: Optional[Dict] = None, pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -107,7 +107,7 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 86a6b723e5..48aece6cd1 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, cast, Dict, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, cast, Dict, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union import torch from torch import nn @@ -79,8 +79,8 @@ def __init__( content_weight: float = 1e5, style_layers: Union[Sequence[str], str] = ["relu1_2", "relu2_2", "relu3_3", "relu4_3"], style_weight: float = 1e10, - optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, learning_rate: float = 1e-3, @@ -112,7 +112,7 @@ def __init__( model=model, loss_fn=perceptual_loss, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, learning_rate=learning_rate, diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index 155126d785..2cbbfd31d0 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -72,8 +72,8 @@ def __init__( backbone_kwargs: Optional[Dict] = None, head: Optional[nn.Module] = None, loss_fn: Optional[Callable] = None, - optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, @@ -88,7 +88,7 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index 9342a61758..909897e345 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -101,8 +101,8 @@ def __init__( backbone_kwargs: Optional[Dict] = None, head: Optional[nn.Module] = None, loss_fn: Optional[Callable] = torch.nn.functional.cross_entropy, - optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, @@ -119,7 +119,7 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index eadf5712b2..9bbb86560c 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -56,8 +56,8 @@ def __init__( num_classes: int, embedding_sizes: List[Tuple[int, int]] = None, 
loss_fn: Callable = F.cross_entropy, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -82,7 +82,7 @@ def __init__( model=model, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index e330fafdc8..5199651ee3 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -56,8 +56,8 @@ def __init__( backbone: Union[str, Tuple[nn.Module, int]] = "mlp-128", backbone_kwargs: Optional[Dict] = None, loss_fn: Optional[Callable] = None, - optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, @@ -69,7 +69,7 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 2c4bf4b0d4..04b8ee72ce 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -62,8 +62,8 @@ def __init__( num_classes: int, backbone: str = "prajjwal1/bert-medium", loss_fn: Optional[Callable] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -85,7 +85,7 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 664f3d7dfe..399b3f42d1 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -95,8 +95,8 @@ def __init__( self, backbone: str = "distilbert-base-uncased", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, 
None] = None, @@ -117,7 +117,7 @@ def __init__( super().__init__( loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index b8a8450d6f..ae294a981f 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -81,8 +81,8 @@ def __init__( self, backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -99,7 +99,7 @@ def __init__( super().__init__( loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index c2e1e8bb58..c21b948047 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -53,8 +53,8 @@ def __init__( self, backbone: str = "sshleifer/distilbart-xsum-1-1", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -70,7 +70,7 @@ def __init__( backbone=backbone, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index 45e1711ca5..9894612f87 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -53,8 +53,8 @@ def __init__( self, backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -70,7 +70,7 @@ def __init__( backbone=backbone, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 57d8c46750..482dd22fa3 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -103,8 +103,8 @@ def 
__init__( backbone_kwargs: Optional[Dict] = None, pretrained: bool = True, loss_fn: Callable = F.cross_entropy, - optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "SGD", + # optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = Accuracy(), @@ -116,7 +116,7 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, + # optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, metrics=metrics, diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 945396bbac..fa0ee4e26c 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -288,11 +288,11 @@ class Foo(ImageClassifier): def test_optimization(tmpdir): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) - # optim = functools.partial(torch.optim.Adam) # (model.parameters()) - # task = ClassificationTask(model, optimizer=optim, scheduler=None) + optim = "Adam" + task = ClassificationTask(model, optimizer=optim, scheduler=None) - # optimizer = task.configure_optimizers() - # assert optimizer == optim + optimizer = task.configure_optimizers() + assert isinstance(optimizer, torch.optim.Adam) task = ClassificationTask(model, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), scheduler=None) optimizer = task.configure_optimizers() @@ -301,15 +301,14 @@ def test_optimization(tmpdir): task = ClassificationTask( model, - optimizer=functools.partial(torch.optim.Adadelta), + optimizer="Adadelta", scheduler=functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1), - # scheduler_kwargs={"step_size": 1}, ) optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) - optim = functools.partial(torch.optim.Adadelta) # (model.parameters()) + optim = functools.partial(torch.optim.Adadelta) scheduler = functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1) task = ClassificationTask(model, optimizer=optim, scheduler=scheduler) optimizer, scheduler = task.configure_optimizers() @@ -321,7 +320,7 @@ def test_optimization(tmpdir): assert isinstance(task.available_schedulers(), list) - optim = functools.partial(torch.optim.Adadelta) # (model.parameters()) + optim = functools.partial(torch.optim.Adadelta) with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."): task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup") optimizer, scheduler = task.configure_optimizers() From 7ea53a2f734b320bc042367922b38cdd7499e5cd Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 22 Sep 2021 21:47:53 +0530 Subject: [PATCH 03/22] Revamp scheduler parameter to use str, Callable, str with params. 
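A rough sketch of the argument forms this patch aims to accept, based on the parametrized test updated below (the `ClassificationTask` model, the `"Adadelta"` optimizer, and the `StepLR` settings are taken from that test and are illustrative, not a finalized public API):

```py
import functools

import torch
from torch import nn

from flash.core.classification import ClassificationTask

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())

# Scheduler referenced by its registry name, with positional args packed into a tuple.
task = ClassificationTask(model, optimizer="Adadelta", scheduler=("StepLR", (1,)))

# Scheduler passed as a callable that maps the optimizer to an _LRScheduler instance.
task = ClassificationTask(
    model,
    optimizer=functools.partial(torch.optim.Adadelta, eps=0.5),
    scheduler=functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1),
)
```

In either form the scheduler is only instantiated inside `configure_optimizers`, after the optimizer has been built from the model parameters.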
--- flash/core/model.py | 40 ++++--- flash/core/optimizers/schedulers.py | 26 ++++- tests/core/test_model.py | 159 +++++++++++++++++++--------- 3 files changed, 157 insertions(+), 68 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 900a106d52..9bc26408d2 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -318,8 +318,8 @@ def __init__( learning_rate: float = 5e-5, optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[str, Callable[..., _LRScheduler]]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[str, Callable, Tuple[str, Tuple[Any, ...]]]] = None, + # scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None, preprocess: Optional[Preprocess] = None, @@ -332,8 +332,10 @@ def __init__( self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn) self.optimizer = optimizer self.scheduler = scheduler + if isinstance(self.scheduler, Tuple): + assert isinstance(self.scheduler[0], str) # self.optimizer_kwargs: Dict[str, Any] = optimizer_kwargs or {} - self.scheduler_kwargs: Dict[str, Any] = scheduler_kwargs or {} + # self.scheduler_kwargs: Dict[str, Any] = scheduler_kwargs or {} self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) @@ -825,24 +827,36 @@ def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float] return round(num_warmup_steps) def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: - scheduler = self.scheduler + if isinstance(self.scheduler, Tuple): + scheduler_key = self.scheduler[0] + scheduler_args = self.scheduler[1] + return self.schedulers.get(scheduler_key)(optimizer, *scheduler_args) + + if isinstance(self.scheduler, str): + self.scheduler = self.schedulers.get(self.scheduler) # , with_metadata=True) + + # If provider is `huggingface`, then maybe use + # else: + # # Otherwise self.scheduler is a Callable + # pass + # if isinstance(scheduler, Callable): # return scheduler(optimizer) - if isinstance(scheduler, str): - scheduler_fn = self.schedulers.get(self.scheduler) - num_training_steps: int = self.get_num_training_steps() - num_warmup_steps: int = self._compute_warmup( - num_training_steps=num_training_steps, - num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), - ) - return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) + # if isinstance(scheduler, str): + # scheduler_fn = self.schedulers.get(self.scheduler) + # num_training_steps: int = self.get_num_training_steps() + # num_warmup_steps: int = self._compute_warmup( + # num_training_steps=num_training_steps, + # num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), + # ) + # return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) # # if issubclass(scheduler, _LRScheduler): # # return scheduler(optimizer, **self.scheduler_kwargs) # raise MisconfigurationException( # "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` " # f"or a built-in scheduler in {self.available_schedulers()}" # ) - return scheduler(optimizer) + return self.scheduler(optimizer) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 
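The string form is also meant to cover user-registered recipes; below is a minimal sketch mirroring the `custom_steplr_configuration` helper added to the tests in this patch (the bare decorator registration relies on the existing `FlashRegistry` call syntax):

```py
import torch
from torch import nn

from flash.core.classification import ClassificationTask


@ClassificationTask.schedulers
def custom_steplr_configuration(optimizer):
    # Any callable that maps an optimizer to an _LRScheduler can be registered by name.
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)


model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())
task = ClassificationTask(model, optimizer="Adadelta", scheduler="custom_steplr_configuration")
```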
diff --git a/flash/core/optimizers/schedulers.py b/flash/core/optimizers/schedulers.py index bfc1bc82b8..2be3122a0e 100644 --- a/flash/core/optimizers/schedulers.py +++ b/flash/core/optimizers/schedulers.py @@ -1,15 +1,33 @@ +import inspect from typing import Callable, List +from torch.optim import lr_scheduler +from torch.optim.lr_scheduler import _LRScheduler + from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE +from flash.core.utilities.providers import _HUGGINGFACE _SCHEDULERS_REGISTRY = FlashRegistry("scheduler") +schedulers: List[_LRScheduler] = [] +for n in dir(lr_scheduler): + sched = getattr(lr_scheduler, n) + + if inspect.isclass(sched) and sched != _LRScheduler and issubclass(sched, _LRScheduler): + schedulers.append(sched) + + +for scheduler in schedulers: + _SCHEDULERS_REGISTRY(scheduler, name=scheduler.__name__) + if _TRANSFORMERS_AVAILABLE: from transformers import optimization - functions: List[Callable] = [ - getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != "get_scheduler") - ] + functions: List[Callable] = [] + for n in dir(optimization): + if "get_" in n and n != "get_scheduler": + functions.append(getattr(optimization, n)) + for fn in functions: - _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:]) + _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:], providers=_HUGGINGFACE) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 25224a7d97..a33c8df1cc 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -24,7 +24,8 @@ import pytorch_lightning as pl import torch from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities.exceptions import MisconfigurationException + +# from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor from torch.nn import functional as F from torch.utils.data import DataLoader @@ -33,7 +34,7 @@ from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask from flash.core.data.process import DefaultPreprocess, Postprocess -from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TEXT_AVAILABLE, Image +from flash.core.utilities.imports import _TABULAR_AVAILABLE, Image # , _TEXT_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING @@ -286,61 +287,117 @@ class Foo(ImageClassifier): assert Foo.available_backbones() == {} -def test_optimization(tmpdir): +@ClassificationTask.schedulers +def custom_steplr_configuration(optimizer): + return torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) - optim = "Adam" - task = ClassificationTask(model, optimizer=optim, scheduler=None) - optimizer = task.configure_optimizers() - assert isinstance(optimizer, torch.optim.Adam) +@pytest.mark.parametrize("optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5)]) +@pytest.mark.parametrize( + "sched", + [ + None, + "custom_steplr_configuration", + functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1), + ("StepLR", (1,)), + ], +) +def test_optimization(tmpdir, optim, sched): - task = ClassificationTask(model, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), scheduler=None) - optimizer = task.configure_optimizers() - assert isinstance(optimizer, torch.optim.Adadelta) - assert optimizer.defaults["eps"] == 0.5 + # Test optimizer: str and scheduler = None + 
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) + task = ClassificationTask(model, optimizer=optim, scheduler=sched) - task = ClassificationTask( - model, - optimizer="Adadelta", - scheduler=functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1), - ) - optimizer, scheduler = task.configure_optimizers() - assert isinstance(optimizer[0], torch.optim.Adadelta) - assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) - - optim = functools.partial(torch.optim.Adadelta) - scheduler = functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1) - task = ClassificationTask(model, optimizer=optim, scheduler=scheduler) - optimizer, scheduler = task.configure_optimizers() - assert isinstance(optimizer[0], torch.optim.Adadelta) - assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) - - if _TEXT_AVAILABLE: - from transformers.optimization import get_linear_schedule_with_warmup - - assert isinstance(task.available_schedulers(), list) - - optim = functools.partial(torch.optim.Adadelta) - with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."): - task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup") - optimizer, scheduler = task.configure_optimizers() - - task = ClassificationTask( - model, - optimizer=optim, - scheduler="linear_schedule_with_warmup", - scheduler_kwargs={"num_warmup_steps": 0.1}, - loss_fn=F.nll_loss, - ) - trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count()) - ds = DummyDataset() - trainer.fit(task, train_dataloader=DataLoader(ds)) + if sched is None: + optimizer = task.configure_optimizers() + assert isinstance(optimizer, torch.optim.Adadelta) + else: optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) - assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) - expected = get_linear_schedule_with_warmup.__name__ - assert scheduler[0].lr_lambdas[0].__qualname__.split(".")[0] == expected + assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + # # Test optimizer: Callable and scheduler = None + # task = ClassificationTask(model, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), scheduler=None) + # optimizer = task.configure_optimizers() + # assert isinstance(optimizer, torch.optim.Adadelta) + # assert optimizer.defaults["eps"] == 0.5 + + # # Test optimizer: str and scheduler: Callable + # task = ClassificationTask( + # model, + # optimizer="Adadelta", + # scheduler=functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1), + # ) + + # # Test optimizer: Callable and scheduler: Callable + # optim = functools.partial(torch.optim.Adadelta) + # scheduler = + # task = ClassificationTask(model, optimizer=optim, scheduler=scheduler) + # optimizer, scheduler = task.configure_optimizers() + # assert isinstance(optimizer[0], torch.optim.Adadelta) + # assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + # # Test optimizer: str and scheduler: Tuple + # task = ClassificationTask( + # model, + # optimizer="Adadelta", + # scheduler=, + # ) + # optimizer, scheduler = task.configure_optimizers() + # assert isinstance(optimizer[0], torch.optim.Adadelta) + # assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + # # Test optimizer: Callable and scheduler: Tuple + # optim = functools.partial(torch.optim.Adadelta) + # scheduler = ("StepLR", (1,)) + # task = ClassificationTask(model, 
optimizer=optim, scheduler=scheduler) + # optimizer, scheduler = task.configure_optimizers() + # assert isinstance(optimizer[0], torch.optim.Adadelta) + # assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + # # Test optimizer: str and scheduler: str + # task = ClassificationTask( + # model, + # optimizer="Adadelta", + # scheduler="custom_steplr_configuration", + # ) + # optimizer, scheduler = task.configure_optimizers() + # assert isinstance(optimizer[0], torch.optim.Adadelta) + # assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + # # Test optimizer: Callable and scheduler: str + # optim = functools.partial(torch.optim.Adadelta) + # scheduler = "custom_steplr_configuration" + # task = ClassificationTask(model, optimizer=optim, scheduler=scheduler) + # optimizer, scheduler = task.configure_optimizers() + # assert isinstance(optimizer[0], torch.optim.Adadelta) + # assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + + # if _TEXT_AVAILABLE: + # from transformers.optimization import get_linear_schedule_with_warmup + + # assert isinstance(task.available_schedulers(), list) + + # optim = functools.partial(torch.optim.Adadelta) + # with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."): + # task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup") + # optimizer, scheduler = task.configure_optimizers() + + # task = ClassificationTask( + # model, + # optimizer=optim, + # scheduler=("linear_schedule_with_warmup", ()), + # scheduler_kwargs={"num_warmup_steps": 0.1}, + # loss_fn=F.nll_loss, + # ) + # trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count()) + # ds = DummyDataset() + # trainer.fit(task, train_dataloader=DataLoader(ds)) + # optimizer, scheduler = task.configure_optimizers() + # assert isinstance(optimizer[0], torch.optim.Adadelta) + # assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) + # expected = get_linear_schedule_with_warmup.__name__ + # assert scheduler[0].lr_lambdas[0].__qualname__.split(".")[0] == expected def test_classification_task_metrics(): From 4cf6cddb66c8a183dccf954b9ac270b5335678b5 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Sun, 26 Sep 2021 21:08:29 +0530 Subject: [PATCH 04/22] Updated _instantiate_scheduler method to handle providers. Added support for HF transformers provided schedulers. 
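A hedged sketch of the intended usage with a Hugging Face transformers scheduler; the recipe names and the convention that `num_warmup_steps` is the first positional argument follow the test added below, and the warmup fraction is only resolved into a step count once the task is attached to a trainer:

```py
from torch import nn
from torch.nn import functional as F

from flash.core.classification import ClassificationTask

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())

# For recipes registered with the Hugging Face provider, the first positional arg is
# num_warmup_steps, given here as a fraction (10%) of the total number of training steps.
task = ClassificationTask(
    model,
    optimizer="Adadelta",
    scheduler=("cosine_schedule_with_warmup", (0.1,)),
    loss_fn=F.nll_loss,
)
```

Because the warmup is computed from the trainer and the train dataloader via `get_num_training_steps`, calling `configure_optimizers` before `trainer.fit` attaches the task is still expected to raise, as the earlier (now commented-out) `MisconfigurationException` check in the tests suggests.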
--- flash/core/model.py | 71 ++++++++++++++++----------- tests/core/test_model.py | 102 +++++++++------------------------------ 2 files changed, 65 insertions(+), 108 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 705e5b1460..3f50923cc8 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -49,6 +49,7 @@ from flash.core.optimizers import _OPTIMIZERS_REGISTRY, _SCHEDULERS_REGISTRY from flash.core.registry import FlashRegistry from flash.core.serve import Composition +from flash.core.utilities import providers from flash.core.utilities.apply_func import get_callable_dict from flash.core.utilities.imports import requires @@ -824,36 +825,50 @@ def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float] return round(num_warmup_steps) def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: - if isinstance(self.scheduler, Tuple): - scheduler_key = self.scheduler[0] - scheduler_args = self.scheduler[1] - return self.schedulers.get(scheduler_key)(optimizer, *scheduler_args) + if isinstance(self.scheduler, Callable): + return self.scheduler(optimizer) if isinstance(self.scheduler, str): - self.scheduler = self.schedulers.get(self.scheduler) # , with_metadata=True) - - # If provider is `huggingface`, then maybe use - # else: - # # Otherwise self.scheduler is a Callable - # pass - - # if isinstance(scheduler, Callable): - # return scheduler(optimizer) - # if isinstance(scheduler, str): - # scheduler_fn = self.schedulers.get(self.scheduler) - # num_training_steps: int = self.get_num_training_steps() - # num_warmup_steps: int = self._compute_warmup( - # num_training_steps=num_training_steps, - # num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), - # ) - # return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) - # # if issubclass(scheduler, _LRScheduler): - # # return scheduler(optimizer, **self.scheduler_kwargs) - # raise MisconfigurationException( - # "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` " - # f"or a built-in scheduler in {self.available_schedulers()}" - # ) - return self.scheduler(optimizer) + return self.schedulers.get(self.scheduler)(optimizer) + + if not isinstance(self.scheduler, Tuple): + raise TypeError("") # Add message + + # By default it is Tuple[str, Tuple[Any, ...]] type now. 
+ scheduler_key = self.scheduler[0] + scheduler_args = self.scheduler[1] + + scheduler = self.schedulers.get(scheduler_key, with_metadata=True) + scheduler_fn = scheduler["fn"] + scheduler_metadata = scheduler["metadata"] + + if "providers" in scheduler_metadata.keys(): + if scheduler_metadata["providers"] == providers._HUGGINGFACE: + num_training_steps: int = self.get_num_training_steps() + num_warmup_steps: int = self._compute_warmup( + num_training_steps=num_training_steps, + num_warmup_steps=scheduler_args[0], # num_warmup_steps is the first arg in all schedulers + ) + if scheduler["name"] == "constant_schedule_with_warmup": + scheduler_args = (num_warmup_steps, *scheduler_args[1:]) + else: + scheduler_args = (num_warmup_steps, num_training_steps, *scheduler_args[1:]) + + # if isinstance(scheduler, str): + # scheduler_fn = self.schedulers.get(self.scheduler) + # num_training_steps: int = self.get_num_training_steps() + # num_warmup_steps: int = self._compute_warmup( + # num_training_steps=num_training_steps, + # num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), + # ) + # return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) + # # if issubclass(scheduler, _LRScheduler): + # # return scheduler(optimizer, **self.scheduler_kwargs) + # raise MisconfigurationException( + # "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` " + # f"or a built-in scheduler in {self.available_schedulers()}" + # ) + return scheduler_fn(optimizer, *scheduler_args) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs diff --git a/tests/core/test_model.py b/tests/core/test_model.py index a33c8df1cc..193966bb68 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -302,9 +302,8 @@ def custom_steplr_configuration(optimizer): ("StepLR", (1,)), ], ) -def test_optimization(tmpdir, optim, sched): +def test_optimizers_and_schedulers(tmpdir, optim, sched): - # Test optimizer: str and scheduler = None model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) task = ClassificationTask(model, optimizer=optim, scheduler=sched) @@ -316,88 +315,31 @@ def test_optimization(tmpdir, optim, sched): assert isinstance(optimizer[0], torch.optim.Adadelta) assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) - # # Test optimizer: Callable and scheduler = None - # task = ClassificationTask(model, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), scheduler=None) - # optimizer = task.configure_optimizers() - # assert isinstance(optimizer, torch.optim.Adadelta) - # assert optimizer.defaults["eps"] == 0.5 - - # # Test optimizer: str and scheduler: Callable - # task = ClassificationTask( - # model, - # optimizer="Adadelta", - # scheduler=functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1), - # ) - - # # Test optimizer: Callable and scheduler: Callable - # optim = functools.partial(torch.optim.Adadelta) - # scheduler = - # task = ClassificationTask(model, optimizer=optim, scheduler=scheduler) - # optimizer, scheduler = task.configure_optimizers() - # assert isinstance(optimizer[0], torch.optim.Adadelta) - # assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) - - # # Test optimizer: str and scheduler: Tuple - # task = ClassificationTask( - # model, - # optimizer="Adadelta", - # scheduler=, - # ) - # optimizer, scheduler = task.configure_optimizers() - # assert isinstance(optimizer[0], torch.optim.Adadelta) - # assert isinstance(scheduler[0], 
torch.optim.lr_scheduler.StepLR) - - # # Test optimizer: Callable and scheduler: Tuple - # optim = functools.partial(torch.optim.Adadelta) - # scheduler = ("StepLR", (1,)) - # task = ClassificationTask(model, optimizer=optim, scheduler=scheduler) - # optimizer, scheduler = task.configure_optimizers() - # assert isinstance(optimizer[0], torch.optim.Adadelta) - # assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) - - # # Test optimizer: str and scheduler: str - # task = ClassificationTask( - # model, - # optimizer="Adadelta", - # scheduler="custom_steplr_configuration", - # ) - # optimizer, scheduler = task.configure_optimizers() - # assert isinstance(optimizer[0], torch.optim.Adadelta) - # assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) - - # # Test optimizer: Callable and scheduler: str - # optim = functools.partial(torch.optim.Adadelta) - # scheduler = "custom_steplr_configuration" - # task = ClassificationTask(model, optimizer=optim, scheduler=scheduler) - # optimizer, scheduler = task.configure_optimizers() - # assert isinstance(optimizer[0], torch.optim.Adadelta) - # assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) - - # if _TEXT_AVAILABLE: - # from transformers.optimization import get_linear_schedule_with_warmup - - # assert isinstance(task.available_schedulers(), list) - - # optim = functools.partial(torch.optim.Adadelta) + +@pytest.mark.parametrize("optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5)]) +@pytest.mark.parametrize( + "sched", + [ + "constant_schedule", + ("cosine_schedule_with_warmup", (0.1,)), + ("cosine_with_hard_restarts_schedule_with_warmup", (0.1, 3)), + ], +) +def test_external_schedulers_provider_hf_transformers(tmpdir, optim, sched): + # with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."): # task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup") # optimizer, scheduler = task.configure_optimizers() - # task = ClassificationTask( - # model, - # optimizer=optim, - # scheduler=("linear_schedule_with_warmup", ()), - # scheduler_kwargs={"num_warmup_steps": 0.1}, - # loss_fn=F.nll_loss, - # ) - # trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count()) - # ds = DummyDataset() - # trainer.fit(task, train_dataloader=DataLoader(ds)) - # optimizer, scheduler = task.configure_optimizers() - # assert isinstance(optimizer[0], torch.optim.Adadelta) - # assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) - # expected = get_linear_schedule_with_warmup.__name__ - # assert scheduler[0].lr_lambdas[0].__qualname__.split(".")[0] == expected + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) + task = ClassificationTask(model, optimizer=optim, scheduler=sched, loss_fn=F.nll_loss) + trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count()) + ds = DummyDataset() + trainer.fit(task, train_dataloader=DataLoader(ds)) + + optimizer, scheduler = task.configure_optimizers() + assert isinstance(optimizer[0], torch.optim.Adadelta) + assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) def test_classification_task_metrics(): From 440aef23e184c4911d21b11ba02f062a6a4b0ec1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 27 Sep 2021 10:07:03 +0100 Subject: [PATCH 05/22] wip --- README.md | 19 ++++++++++ flash/core/model.py | 55 ++++++++++++++++------------- flash/core/optimizers/optimizers.py | 18 +++++++++- 
flash/core/optimizers/schedulers.py | 15 ++++++-- flash/core/utilities/imports.py | 2 ++ flash/image/classification/model.py | 2 +- tests/core/test_model.py | 26 ++++++++++++-- 7 files changed, 105 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 3cee739f3e..9112d23e61 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,25 @@ In detail, the following methods are currently implemented: * **[metaoptnet](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_metaoptnet.py)** : from Lee *et al.* 2019, [Meta-Learning with Differentiable Convex Optimization](https://arxiv.org/abs/1904.03758) * **[anil](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_anil.py)** : from Raghu *et al.* 2020, [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML](https://arxiv.org/abs/1909.09157) + +### Flash Optimizers / Schedulers + +With Flash, swapping among 40+ optimizers and 15+ scheduler recipes is simple. Find the list of available optimizers and schedulers as follows: + +```py +ImageClassifier.available_optimizers() +# ['A2GradExp', ..., 'Yogi'] + +ImageClassifier.available_schedulers() +# ['CosineAnnealingLR', 'CosineAnnealingWarmRestarts', ..., 'polynomial_decay_schedule_with_warmup'] +``` + +Once you've chosen, create the model: + +```py +model = ImageClassifier(backbone="resnet18", optimizer='yogi', scheduler="cosine_with_hard_restarts_schedule_with_warmup", num_classes=2) +``` + ### Flash Transforms diff --git a/flash/core/model.py b/flash/core/model.py index 3f50923cc8..7ec5a85570 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -24,6 +24,7 @@ import torchmetrics from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback +from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.enums import LightningEnum @@ -316,7 +317,9 @@ def __init__( learning_rate: float = 5e-5, optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[str, Callable, Tuple[str, Tuple[Any, ...]]]] = None, + scheduler: Optional[ + Union[Union[str, Dict[str, Any]], Callable, Tuple[Union[str, Dict[str, Any]], Tuple[Any, ...]]] + ] = None, # scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None, @@ -331,9 +334,7 @@ def __init__( self.optimizer = optimizer self.scheduler = scheduler if isinstance(self.scheduler, Tuple): - assert isinstance(self.scheduler[0], str) - # self.optimizer_kwargs: Dict[str, Any] = optimizer_kwargs or {} - # self.scheduler_kwargs: Dict[str, Any] = scheduler_kwargs or {} + assert isinstance(self.scheduler[0], (str, dict)) self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) @@ -479,7 +480,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: if isinstance(self.optimizer, str): - self.optimizer = self.optimizers.get(self.optimizer) + self.optimizer =
self.optimizers.get(self.optimizer.lower()) model_parameters = filter(lambda p: p.requires_grad, self.parameters()) optimizer: Optimizer = self.optimizer(model_parameters, lr=self.learning_rate) @@ -782,6 +783,13 @@ def get_backbone_details(cls, key) -> List[str]: return [] return list(inspect.signature(registry.get(key)).parameters.items()) + @classmethod + def available_optimizers(cls) -> List[str]: + registry: Optional[FlashRegistry] = getattr(cls, "optimizers", None) + if registry is None: + return [] + return registry.available_keys() + @classmethod def available_schedulers(cls) -> List[str]: registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None) @@ -824,7 +832,7 @@ def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float] num_warmup_steps *= num_training_steps return round(num_warmup_steps) - def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: + def _instantiate_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: if isinstance(self.scheduler, Callable): return self.scheduler(optimizer) @@ -832,13 +840,22 @@ def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: return self.schedulers.get(self.scheduler)(optimizer) if not isinstance(self.scheduler, Tuple): - raise TypeError("") # Add message + raise TypeError("The scheduler arguments should be provided as a tuple.") - # By default it is Tuple[str, Tuple[Any, ...]] type now. - scheduler_key = self.scheduler[0] + scheduler_key_or_config = self.scheduler[0] scheduler_args = self.scheduler[1] - scheduler = self.schedulers.get(scheduler_key, with_metadata=True) + if isinstance(scheduler_key_or_config, dict): + scheduler_key = scheduler_key_or_config["scheduler"] + scheduler_config = scheduler_key_or_config + else: + scheduler_key = scheduler_key_or_config + scheduler_config = _get_default_scheduler_config() + scheduler_config["interval"] = None + + scheduler_config = deepcopy(scheduler_config) + + scheduler = self.schedulers.get(scheduler_key.lower(), with_metadata=True) scheduler_fn = scheduler["fn"] scheduler_metadata = scheduler["metadata"] @@ -854,21 +871,9 @@ def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: else: scheduler_args = (num_warmup_steps, num_training_steps, *scheduler_args[1:]) - # if isinstance(scheduler, str): - # scheduler_fn = self.schedulers.get(self.scheduler) - # num_training_steps: int = self.get_num_training_steps() - # num_warmup_steps: int = self._compute_warmup( - # num_training_steps=num_training_steps, - # num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"), - # ) - # return scheduler_fn(optimizer, num_warmup_steps, num_training_steps) - # # if issubclass(scheduler, _LRScheduler): - # # return scheduler(optimizer, **self.scheduler_kwargs) - # raise MisconfigurationException( - # "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` " - # f"or a built-in scheduler in {self.available_schedulers()}" - # ) - return scheduler_fn(optimizer, *scheduler_args) + scheduler_config["scheduler"] = scheduler_fn(optimizer, *scheduler_args) + scheduler_config["interval"] = scheduler_config["interval"] or scheduler_metadata["interval"] + return scheduler_config def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs diff --git a/flash/core/optimizers/optimizers.py b/flash/core/optimizers/optimizers.py index 8f7d22b935..c54f99905a 100644 --- a/flash/core/optimizers/optimizers.py +++ b/flash/core/optimizers/optimizers.py @@ -1,12 +1,28 @@ 
+from inspect import isclass from typing import Callable, List from torch import optim from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE _OPTIMIZERS_REGISTRY = FlashRegistry("optimizer") _optimizers: List[Callable] = [getattr(optim, n) for n in dir(optim) if ("_" not in n)] for fn in _optimizers: - _OPTIMIZERS_REGISTRY(fn, name=fn.__name__) + _OPTIMIZERS_REGISTRY(fn, name=fn.__name__.lower()) + +if _TORCH_OPTIMIZER_AVAILABLE: + import torch_optimizer + + _optimizers: List[Callable] = [ + getattr(torch_optimizer, n) + for n in dir(torch_optimizer) + if ("_" not in n) and isclass(getattr(torch_optimizer, n)) + ] + + for fn in _optimizers: + name = fn.__name__.lower() + if name not in _OPTIMIZERS_REGISTRY: + _OPTIMIZERS_REGISTRY(fn, name=name) diff --git a/flash/core/optimizers/schedulers.py b/flash/core/optimizers/schedulers.py index 2be3122a0e..2e36c41ec4 100644 --- a/flash/core/optimizers/schedulers.py +++ b/flash/core/optimizers/schedulers.py @@ -2,13 +2,21 @@ from typing import Callable, List from torch.optim import lr_scheduler -from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.lr_scheduler import ( + _LRScheduler, + CosineAnnealingLR, + CosineAnnealingWarmRestarts, + CyclicLR, + MultiStepLR, + StepLR, +) from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE from flash.core.utilities.providers import _HUGGINGFACE _SCHEDULERS_REGISTRY = FlashRegistry("scheduler") +_STEP_SCHEDULERS = (StepLR, MultiStepLR, CosineAnnealingLR, CyclicLR, CosineAnnealingWarmRestarts) schedulers: List[_LRScheduler] = [] for n in dir(lr_scheduler): @@ -19,7 +27,8 @@ for scheduler in schedulers: - _SCHEDULERS_REGISTRY(scheduler, name=scheduler.__name__) + interval = "step" if issubclass(scheduler, _STEP_SCHEDULERS) else "epoch" + _SCHEDULERS_REGISTRY(scheduler, name=scheduler.__name__.lower(), interval=interval) if _TRANSFORMERS_AVAILABLE: from transformers import optimization @@ -30,4 +39,4 @@ functions.append(getattr(optimization, n)) for fn in functions: - _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:], providers=_HUGGINGFACE) + _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:].lower(), providers=_HUGGINGFACE, interval="step") diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index f138eaf37e..abaffb3193 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -102,6 +102,8 @@ def _compare_version(package: str, op, version) -> bool: _VISSL_AVAILABLE = _module_available("vissl") and _module_available("classy_vision") _ALBUMENTATIONS_AVAILABLE = _module_available("albumentations") _BAAL_AVAILABLE = _module_available("baal") +_TORCH_OPTIMIZER_AVAILABLE = _module_available("torch_optimizer") + if _PIL_AVAILABLE: from PIL import Image # noqa: F401 diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index b428b456c3..d1a1147aff 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -82,7 +82,7 @@ def __init__( pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, diff --git 
a/tests/core/test_model.py b/tests/core/test_model.py index 193966bb68..0e01c523ed 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -23,6 +23,7 @@ import pytest import pytorch_lightning as pl import torch +import torchmetrics from pytorch_lightning.callbacks import Callback # from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -148,7 +149,7 @@ def __init__(self, child): # ================================ -@pytest.mark.parametrize("metrics", [None, pl.metrics.Accuracy(), {"accuracy": pl.metrics.Accuracy()}]) +@pytest.mark.parametrize("metrics", [None, torchmetrics.Accuracy(), {"accuracy": torchmetrics.Accuracy()}]) def test_classificationtask_train(tmpdir: str, metrics: Any): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) @@ -300,12 +301,20 @@ def custom_steplr_configuration(optimizer): "custom_steplr_configuration", functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1), ("StepLR", (1,)), + ( + { + "scheduler": "StepLR", + "interval": None, # after epoch is over + }, + (1,), + ), ], ) def test_optimizers_and_schedulers(tmpdir, optim, sched): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) task = ClassificationTask(model, optimizer=optim, scheduler=sched) + train_dl = torch.utils.data.DataLoader(DummyDataset()) if sched is None: optimizer = task.configure_optimizers() @@ -313,7 +322,20 @@ def test_optimizers_and_schedulers(tmpdir, optim, sched): else: optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) - assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR) + scheduler = scheduler[0] + if isinstance(scheduler, dict): + assert isinstance(scheduler["scheduler"], torch.optim.lr_scheduler.StepLR) + assert scheduler["interval"] == "step" + else: + assert isinstance(scheduler, torch.optim.lr_scheduler.StepLR) + + # generate a checkpoint + trainer = flash.Trainer( + default_root_dir=tmpdir, + limit_train_batches=10, + max_epochs=1, + ) + trainer.fit(task, train_dl) @pytest.mark.parametrize("optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5)]) From 06e7722ea38d9e35ae33aeec747042f7f766417b Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 29 Sep 2021 16:42:21 +0530 Subject: [PATCH 06/22] Updated scheduler parameter to take input as type Tuple[str, Dict[str, Any]]. Added necessary tests as well. 
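With this change, both `optimizer` and `scheduler` accept a registry key together with a kwargs dictionary, and scheduler-config keys such as "interval" can be mixed into that dictionary. A minimal usage sketch of the tuple form, mirroring the updated tests (the model and all values below are illustrative only):

```py
import torch
from torch import nn

from flash.core.classification import ClassificationTask

# A throwaway model, as in the updated tests; the layer sizes are illustrative only.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())

task = ClassificationTask(
    model,
    optimizer=("Adadelta", {"eps": 0.5}),
    scheduler=("StepLR", {"step_size": 10, "interval": "step"}),
)

# configure_optimizers() now returns ([optimizer], [scheduler_config_dict]).
optimizers, schedulers = task.configure_optimizers()
assert isinstance(schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR)
assert schedulers[0]["interval"] == "step"
```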
--- flash/core/model.py | 124 +++++++++++++++++++++++++++++---------- tests/core/test_model.py | 62 +++++++++++--------- 2 files changed, 128 insertions(+), 58 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 7ec5a85570..84777355f9 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -315,11 +315,9 @@ def __init__( model: Optional[nn.Module] = None, loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, learning_rate: float = 5e-5, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[ - Union[Union[str, Dict[str, Any]], Callable, Tuple[Union[str, Dict[str, Any]], Tuple[Any, ...]]] - ] = None, + scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, # scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None, @@ -332,9 +330,12 @@ def __init__( self.model = model self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn) self.optimizer = optimizer + if isinstance(self.optimizer, Tuple): + assert isinstance(self.optimizer[0], str) or isinstance(self.optimizer[0], Callable) + self.scheduler = scheduler if isinstance(self.scheduler, Tuple): - assert isinstance(self.scheduler[0], (str, dict)) + assert isinstance(self.scheduler[0], str) or isinstance(self.scheduler[0], Callable) self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) @@ -480,10 +481,31 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: if isinstance(self.optimizer, str): - self.optimizer = self.optimizers.get(self.optimizer.lower()) + optimizer_fn = self.optimizers.get(self.optimizer.lower()) + _optimizers_kwargs: Dict[str, Any] = {} + elif isinstance(self.optimizer, Callable): + optimizer_fn = self.optimizer + _optimizers_kwargs: Dict[str, Any] = {} + elif isinstance(self.optimizer, Tuple): + optimizer_fn: Callable = None + optimizer_key: str = self.optimizer[0] + + if not isinstance(optimizer_key, str): + raise MisconfigurationException( + f"Please provide a key from the available optimizers. \ + Refer to {self.__class__.__name__}.available_optimizers" + ) + + optimizer_fn = self.optimizers.get(optimizer_key.lower()) + _optimizers_kwargs: Dict[str, Any] = self.optimizer[1] + else: + raise TypeError( + f"Optimizer should be of type string or callable or tuple(string, dictionary) \ + but got {type(self.optimizer)}." 
+ ) model_parameters = filter(lambda p: p.requires_grad, self.parameters()) - optimizer: Optimizer = self.optimizer(model_parameters, lr=self.learning_rate) + optimizer: Optimizer = optimizer_fn(model_parameters, lr=self.learning_rate, **_optimizers_kwargs) # if not isinstance(self.optimizer, Optimizer): # self.optimizer_kwargs["lr"] = self.learning_rate # optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs) @@ -833,46 +855,86 @@ def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float] return round(num_warmup_steps) def _instantiate_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: - if isinstance(self.scheduler, Callable): - return self.scheduler(optimizer) + if isinstance(self.scheduler, str) or isinstance(self.scheduler, Callable): + + # Get values based on type. + if isinstance(self.scheduler, str): + _scheduler = self.schedulers.get(self.scheduler, with_metadata=True) + scheduler_fn: Callable = _scheduler["fn"] + scheduler_metadata: Dict[str, Any] = _scheduler["metadata"] + else: + scheduler_fn: Callable = self.scheduler + + # Generate the output: could be a scheduler object or a scheduler config. + sched_output: Union[_LRScheduler, Dict[str, Any]] = scheduler_fn(optimizer) + + # Create and/or update a scheduler configuration + scheduler_config = _get_default_scheduler_config() + if isinstance(sched_output, _LRScheduler): + scheduler_config["scheduler"] = sched_output + if isinstance(self.scheduler, str) and "interval" in scheduler_metadata.keys(): + scheduler_config["interval"] = scheduler_metadata["interval"] + elif isinstance(sched_output, dict): + for key, value in sched_output.items(): + scheduler_config[key] = value + else: + if isinstance(self.scheduler, str): + message = "register a custom callable" + else: + message = "provide a callable" + raise MisconfigurationException( + f"Please {message} that outputs either an LR Scheduler or a scheduler configuration." ) - if isinstance(self.scheduler, str): - return self.schedulers.get(self.scheduler)(optimizer) + return scheduler_config if not isinstance(self.scheduler, Tuple): raise TypeError("The scheduler arguments should be provided as a tuple.") - scheduler_key_or_config = self.scheduler[0] - scheduler_args = self.scheduler[1] + if not isinstance(self.scheduler[0], str): + raise TypeError( + f"The first value in scheduler argument tuple should be either a string or a callable \ but got {type(self.scheduler[0])}." ) - if isinstance(scheduler_key_or_config, dict): - scheduler_key = scheduler_key_or_config["scheduler"] - scheduler_config = scheduler_key_or_config - else: - scheduler_key = scheduler_key_or_config - scheduler_config = _get_default_scheduler_config() - scheduler_config["interval"] = None + # Separate the key and the kwargs. + scheduler_key_or_fn: Union[str, Callable] = self.scheduler[0] + scheduler_kwargs_and_config: Dict[str, Any] = self.scheduler[1] + + # Get the default scheduler config. + scheduler_config: Dict[str, Any] = _get_default_scheduler_config() + scheduler_config["interval"] = None + + # Update scheduler config from the kwargs and pop the keys from the kwargs at the same time. + for config_key, config_value in scheduler_config.items(): + scheduler_config[config_key] = scheduler_kwargs_and_config.pop(config_key, None) or config_value - scheduler_config = deepcopy(scheduler_config) + # Create a new copy of the kwargs.
+ scheduler_kwargs = deepcopy(scheduler_kwargs_and_config) + assert all(config_key not in scheduler_kwargs.keys() for config_key in scheduler_config.keys()) - scheduler = self.schedulers.get(scheduler_key.lower(), with_metadata=True) - scheduler_fn = scheduler["fn"] - scheduler_metadata = scheduler["metadata"] + # Retrieve the scheduler callable with metadata from the registry. + _scheduler = self.schedulers.get(scheduler_key_or_fn.lower(), with_metadata=True) + scheduler_fn: Callable = _scheduler["fn"] + scheduler_metadata: Dict[str, Any] = _scheduler["metadata"] + # Make necessary adjustments to the kwargs based on the provider of the scheduler. if "providers" in scheduler_metadata.keys(): if scheduler_metadata["providers"] == providers._HUGGINGFACE: num_training_steps: int = self.get_num_training_steps() num_warmup_steps: int = self._compute_warmup( num_training_steps=num_training_steps, - num_warmup_steps=scheduler_args[0], # num_warmup_steps is the first arg in all schedulers + num_warmup_steps=scheduler_kwargs["num_warmup_steps"], ) - if scheduler["name"] == "constant_schedule_with_warmup": - scheduler_args = (num_warmup_steps, *scheduler_args[1:]) - else: - scheduler_args = (num_warmup_steps, num_training_steps, *scheduler_args[1:]) + scheduler_kwargs["num_warmup_steps"] = num_warmup_steps + scheduler_kwargs["num_training_steps"] = num_training_steps + + # Set the scheduler in the config. + scheduler_config["scheduler"] = scheduler_fn(optimizer, **scheduler_kwargs) - scheduler_config["scheduler"] = scheduler_fn(optimizer, *scheduler_args) - scheduler_config["interval"] = scheduler_config["interval"] or scheduler_metadata["interval"] + # Update the interval in sched config just in case it has NoneType. + if "interval" in scheduler_metadata.keys(): + scheduler_config["interval"] = scheduler_config["interval"] or scheduler_metadata["interval"] return scheduler_config def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 0e01c523ed..c553ba8f14 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -25,8 +25,6 @@ import torch import torchmetrics from pytorch_lightning.callbacks import Callback - -# from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor from torch.nn import functional as F from torch.utils.data import DataLoader @@ -35,7 +33,7 @@ from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask from flash.core.data.process import DefaultPreprocess, Postprocess -from flash.core.utilities.imports import _TABULAR_AVAILABLE, Image # , _TEXT_AVAILABLE +from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TEXT_AVAILABLE, Image from flash.image import ImageClassificationData, ImageClassifier from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING @@ -289,28 +287,39 @@ class Foo(ImageClassifier): @ClassificationTask.schedulers -def custom_steplr_configuration(optimizer): - return torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) +def custom_steplr_configuration_return_as_instance(optimizer): + return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10) + + +@ClassificationTask.schedulers +def custom_steplr_configuration_return_as_dict(optimizer): + return { + "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=10), + "name": "A_Really_Cool_Name", + "interval": "step", + "frequency": 1, + "reduce_on_plateau": False, + "monitor": None, + "strict": True, + "opt_idx": None, + }
-@pytest.mark.parametrize("optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5)]) @pytest.mark.parametrize( - "sched", + "optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5), ("Adadelta", {"eps": 0.5})] +) +@pytest.mark.parametrize( + "sched, interval", [ - None, - "custom_steplr_configuration", - functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1), - ("StepLR", (1,)), - ( - { - "scheduler": "StepLR", - "interval": None, # after epoch is over - }, - (1,), - ), + (None, "epoch"), + ("custom_steplr_configuration_return_as_instance", "epoch"), + ("custom_steplr_configuration_return_as_dict", "step"), + (functools.partial(torch.optim.lr_scheduler.StepLR, step_size=10), "epoch"), + (("StepLR", {"step_size": 10}), "step"), + (("StepLR", {"step_size": 10, "interval": None}), "step"), ], ) -def test_optimizers_and_schedulers(tmpdir, optim, sched): +def test_optimizers_and_schedulers(tmpdir, optim, sched, interval): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) task = ClassificationTask(model, optimizer=optim, scheduler=sched) @@ -322,12 +331,10 @@ def test_optimizers_and_schedulers(tmpdir, optim, sched): else: optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) + scheduler = scheduler[0] - if isinstance(scheduler, dict): - assert isinstance(scheduler["scheduler"], torch.optim.lr_scheduler.StepLR) - assert scheduler["interval"] == "step" - else: - assert isinstance(scheduler, torch.optim.lr_scheduler.StepLR) + assert isinstance(scheduler["scheduler"], torch.optim.lr_scheduler.StepLR) + assert scheduler["interval"] == interval # generate a checkpoint trainer = flash.Trainer( @@ -338,13 +345,14 @@ def test_optimizers_and_schedulers(tmpdir, optim, sched): trainer.fit(task, train_dl) +@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") @pytest.mark.parametrize("optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5)]) @pytest.mark.parametrize( "sched", [ "constant_schedule", - ("cosine_schedule_with_warmup", (0.1,)), - ("cosine_with_hard_restarts_schedule_with_warmup", (0.1, 3)), + ("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}), + ("cosine_with_hard_restarts_schedule_with_warmup", {"num_warmup_steps": 0.1, "num_cycles": 3}), ], ) def test_external_schedulers_provider_hf_transformers(tmpdir, optim, sched): @@ -361,7 +369,7 @@ def test_external_schedulers_provider_hf_transformers(tmpdir, optim, sched): optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) - assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) + assert isinstance(scheduler[0]["scheduler"], torch.optim.lr_scheduler.LambdaLR) def test_classification_task_metrics(): From 8ab54bd9727417c413c96931d8bfe1f77ffcdbff Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 29 Sep 2021 17:40:04 +0530 Subject: [PATCH 07/22] Update naming of scheduler parameter to lr_scheduler. --- flash/core/model.py | 104 +++++++++++++++++++-------------------- tests/core/test_model.py | 8 +-- 2 files changed, 54 insertions(+), 58 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 84777355f9..f74394b4d2 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -299,6 +299,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check model: Model to use for the task. 
loss_fn: Loss function for training optimizer: Optimizer to use for training, defaults to :class:`torch.optim.Adam`. + lr_scheduler: The scheduler or scheduler class to use. metrics: Metrics to compute for training and evaluation. learning_rate: Learning rate to use for training, defaults to ``5e-5``. preprocess: :class:`~flash.core.data.process.Preprocess` to use as the default for this task. @@ -306,7 +307,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check """ optimizers: FlashRegistry = _OPTIMIZERS_REGISTRY - schedulers: FlashRegistry = _SCHEDULERS_REGISTRY + lr_schedulers: FlashRegistry = _SCHEDULERS_REGISTRY required_extras: Optional[Union[str, List[str]]] = None @@ -316,9 +317,7 @@ def __init__( loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, learning_rate: float = 5e-5, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, - # scheduler_kwargs: Optional[Dict[str, Any]] = None, + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None, preprocess: Optional[Preprocess] = None, @@ -333,9 +332,9 @@ def __init__( if isinstance(self.optimizer, Tuple): assert isinstance(self.optimizer[0], str) or isinstance(self.optimizer[0], Callable) - self.scheduler = scheduler - if isinstance(self.scheduler, Tuple): - assert isinstance(self.scheduler[0], str) or isinstance(self.scheduler[0], Callable) + self.lr_scheduler = lr_scheduler + if isinstance(self.lr_scheduler, Tuple): + assert isinstance(self.lr_scheduler[0], str) or isinstance(self.lr_scheduler[0], Callable) self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) @@ -506,11 +505,8 @@ def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_ model_parameters = filter(lambda p: p.requires_grad, self.parameters()) optimizer: Optimizer = optimizer_fn(model_parameters, lr=self.learning_rate, **_optimizers_kwargs) - # if not isinstance(self.optimizer, Optimizer): - # self.optimizer_kwargs["lr"] = self.learning_rate - # optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs) - if self.scheduler is not None: - return [optimizer], [self._instantiate_scheduler(optimizer)] + if self.lr_scheduler is not None: + return [optimizer], [self._instantiate_lr_scheduler(optimizer)] return optimizer @staticmethod @@ -813,8 +809,8 @@ def available_optimizers(cls) -> List[str]: return registry.available_keys() @classmethod - def available_schedulers(cls) -> List[str]: - registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None) + def available_lr_schedulers(cls) -> List[str]: + registry: Optional[FlashRegistry] = getattr(cls, "lr_schedulers", None) if registry is None: return [] return registry.available_keys() @@ -854,31 +850,31 @@ def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float] num_warmup_steps *= num_training_steps return round(num_warmup_steps) - def _instantiate_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: - if isinstance(self.scheduler, str) or isinstance(self.scheduler, Callable): + def _instantiate_lr_scheduler(self, optimizer: Optimizer) 
-> Dict[str, Any]: + if isinstance(self.lr_scheduler, str) or isinstance(self.lr_scheduler, Callable): # Get values based in type. - if isinstance(self.scheduler, str): - _scheduler = self.schedulers.get(self.scheduler, with_metadata=True) - scheduler_fn: Callable = _scheduler["fn"] - scheduler_metadata: Dict[str, Any] = _scheduler["metadata"] + if isinstance(self.lr_scheduler, str): + _lr_scheduler = self.lr_schedulers.get(self.lr_scheduler, with_metadata=True) + lr_scheduler_fn: Callable = _lr_scheduler["fn"] + lr_scheduler_metadata: Dict[str, Any] = _lr_scheduler["metadata"] else: - scheduler_fn: Callable = self.scheduler + lr_scheduler_fn: Callable = self.lr_scheduler - # Generate the output: could be a scheduler object or a scheduler config. - sched_output: Union[_LRScheduler, Dict[str, Any]] = scheduler_fn(optimizer) + # Generate the output: could be a lr_scheduler object or a lr_scheduler config. + sched_output: Union[_LRScheduler, Dict[str, Any]] = lr_scheduler_fn(optimizer) - # Create and/or update a scheduler configuration - scheduler_config = _get_default_scheduler_config() + # Create and/or update a lr_scheduler configuration + lr_scheduler_config = _get_default_scheduler_config() if isinstance(sched_output, _LRScheduler): - scheduler_config["scheduler"] = sched_output - if isinstance(self.scheduler, str) and "interval" in scheduler_metadata.keys(): - scheduler_config["interval"] = scheduler_metadata["interval"] + lr_scheduler_config["scheduler"] = sched_output + if isinstance(self.lr_scheduler, str) and "interval" in lr_scheduler_metadata.keys(): + lr_scheduler_config["interval"] = lr_scheduler_metadata["interval"] elif isinstance(sched_output, dict): for key, value in sched_output.items(): - scheduler_config[key] = value + lr_scheduler_config[key] = value else: - if isinstance(self.scheduler, str): + if isinstance(self.lr_scheduler, str): message = "register a custom callable" else: message = "provide a callable" @@ -886,56 +882,56 @@ def _instantiate_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: f"Please {message} that outputs either an LR Scheduler or a scheduler condifguration." ) - return scheduler_config + return lr_scheduler_config - if not isinstance(self.scheduler, Tuple): + if not isinstance(self.lr_scheduler, Tuple): raise TypeError("The scheduler arguments should be provided as a tuple.") - if not isinstance(self.scheduler[0], str): + if not isinstance(self.lr_scheduler[0], str): raise TypeError( f"The first value in scheduler argument tuple should be either a string or a callable \ - but got {type(self.scheduler[0])}." + but got {type(self.lr_scheduler[0])}." ) # Separate the key and the kwargs. - scheduler_key_or_fn: Union[str, Callable] = self.scheduler[0] - scheduler_kwargs_and_config: Dict[str, Any] = self.scheduler[1] + lr_scheduler_key_or_fn: Union[str, Callable] = self.lr_scheduler[0] + lr_scheduler_kwargs_and_config: Dict[str, Any] = self.lr_scheduler[1] # Get the default scheduler config. - scheduler_config: Dict[str, Any] = _get_default_scheduler_config() - scheduler_config["interval"] = None + lr_scheduler_config: Dict[str, Any] = _get_default_scheduler_config() + lr_scheduler_config["interval"] = None # Update scheduler config from the kwargs and pop the keys from the kwargs at the same time. 
- for config_key, config_value in scheduler_config.items(): - scheduler_config[config_key] = scheduler_kwargs_and_config.pop(config_key, None) or config_value + for config_key, config_value in lr_scheduler_config.items(): + lr_scheduler_config[config_key] = lr_scheduler_kwargs_and_config.pop(config_key, None) or config_value # Create a new copy of the kwargs. - scheduler_kwargs = deepcopy(scheduler_kwargs_and_config) - assert all(config_key not in scheduler_kwargs.keys() for config_key in scheduler_config.keys()) + lr_scheduler_kwargs = deepcopy(lr_scheduler_kwargs_and_config) + assert all(config_key not in lr_scheduler_kwargs.keys() for config_key in lr_scheduler_config.keys()) # Retreive the scheduler callable with metadata from the registry. - _scheduler = self.schedulers.get(scheduler_key_or_fn.lower(), with_metadata=True) - scheduler_fn: Callable = _scheduler["fn"] - scheduler_metadata: Dict[str, Any] = _scheduler["metadata"] + _lr_scheduler = self.lr_schedulers.get(lr_scheduler_key_or_fn.lower(), with_metadata=True) + lr_scheduler_fn: Callable = _lr_scheduler["fn"] + lr_scheduler_metadata: Dict[str, Any] = _lr_scheduler["metadata"] # Make necessary adjustment to the kwargs based on the provider of the scheduler. - if "providers" in scheduler_metadata.keys(): - if scheduler_metadata["providers"] == providers._HUGGINGFACE: + if "providers" in lr_scheduler_metadata.keys(): + if lr_scheduler_metadata["providers"] == providers._HUGGINGFACE: num_training_steps: int = self.get_num_training_steps() num_warmup_steps: int = self._compute_warmup( num_training_steps=num_training_steps, - num_warmup_steps=scheduler_kwargs["num_warmup_steps"], + num_warmup_steps=lr_scheduler_kwargs["num_warmup_steps"], ) - scheduler_kwargs["num_warmup_steps"] = num_warmup_steps - scheduler_kwargs["num_training_steps"] = num_training_steps + lr_scheduler_kwargs["num_warmup_steps"] = num_warmup_steps + lr_scheduler_kwargs["num_training_steps"] = num_training_steps # Set the scheduler in the config. - scheduler_config["scheduler"] = scheduler_fn(optimizer, **scheduler_kwargs) + lr_scheduler_config["scheduler"] = lr_scheduler_fn(optimizer, **lr_scheduler_kwargs) # Update the interval in sched config just in case it has NoneType. 
- if "interval" in scheduler_metadata.keys(): - scheduler_config["interval"] = scheduler_config["interval"] or scheduler_metadata["interval"] - return scheduler_config + if "interval" in lr_scheduler_metadata.keys(): + lr_scheduler_config["interval"] = lr_scheduler_config["interval"] or lr_scheduler_metadata["interval"] + return lr_scheduler_config def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs diff --git a/tests/core/test_model.py b/tests/core/test_model.py index c553ba8f14..968cdcd743 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -286,12 +286,12 @@ class Foo(ImageClassifier): assert Foo.available_backbones() == {} -@ClassificationTask.schedulers +@ClassificationTask.lr_schedulers def custom_steplr_configuration_return_as_instance(optimizer): return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10) -@ClassificationTask.schedulers +@ClassificationTask.lr_schedulers def custom_steplr_configuration_return_as_dict(optimizer): return { "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=10), @@ -322,7 +322,7 @@ def custom_steplr_configuration_return_as_dict(optimizer): def test_optimizers_and_schedulers(tmpdir, optim, sched, interval): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) - task = ClassificationTask(model, optimizer=optim, scheduler=sched) + task = ClassificationTask(model, optimizer=optim, lr_scheduler=sched) train_dl = torch.utils.data.DataLoader(DummyDataset()) if sched is None: @@ -362,7 +362,7 @@ def test_external_schedulers_provider_hf_transformers(tmpdir, optim, sched): # optimizer, scheduler = task.configure_optimizers() model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) - task = ClassificationTask(model, optimizer=optim, scheduler=sched, loss_fn=F.nll_loss) + task = ClassificationTask(model, optimizer=optim, lr_scheduler=sched, loss_fn=F.nll_loss) trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count()) ds = DummyDataset() trainer.fit(task, train_dataloader=DataLoader(ds)) From 617e53aa268c27d08965b6482cf5903f48756b00 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 29 Sep 2021 18:04:19 +0530 Subject: [PATCH 08/22] Update optimizer and lr_scheduler parameter across all tasks. 
--- flash/audio/speech_recognition/model.py | 17 +++++------------ flash/graph/classification/model.py | 17 +++++------------ flash/image/classification/model.py | 18 +++++------------- flash/image/detection/model.py | 19 +++++-------------- flash/image/embedding/model.py | 19 +++++-------------- flash/image/instance_segmentation/model.py | 19 +++++-------------- flash/image/keypoint_detection/model.py | 18 +++++------------- flash/image/segmentation/model.py | 16 +++++----------- flash/image/style_transfer/model.py | 16 +++++----------- flash/pointcloud/detection/model.py | 16 +++++----------- flash/pointcloud/segmentation/model.py | 18 +++++++----------- flash/tabular/classification/model.py | 12 ++++-------- flash/template/classification/model.py | 18 +++++++----------- flash/text/classification/model.py | 18 +++++++----------- flash/text/question_answering/model.py | 18 +++++++----------- flash/text/seq2seq/core/model.py | 18 +++++++----------- flash/text/seq2seq/summarization/model.py | 18 +++++++----------- flash/text/seq2seq/translation/model.py | 17 +++++------------ flash/video/classification/model.py | 18 +++++++----------- 19 files changed, 108 insertions(+), 222 deletions(-) diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index 9486636c3d..5e77a919bd 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -13,11 +13,10 @@ # limitations under the License. import os import warnings -from typing import Any, Callable, Dict, Mapping, Optional, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union import torch import torch.nn as nn -from torch.optim.lr_scheduler import _LRScheduler from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding @@ -40,9 +39,7 @@ class SpeechRecognition(Task): backbone: Any speech recognition model from `HuggingFace/transformers `_. optimizer: Optimizer to use for training. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. learning_rate: Learning rate to use for training, defaults to ``1e-3``. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. 
""" @@ -54,10 +51,8 @@ class SpeechRecognition(Task): def __init__( self, backbone: str = "facebook/wav2vec2-base-960h", - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, learning_rate: float = 1e-5, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, ): @@ -71,9 +66,7 @@ def __init__( super().__init__( model=model, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, learning_rate=learning_rate, serializer=serializer, ) diff --git a/flash/graph/classification/model.py b/flash/graph/classification/model.py index 26eb8bc90d..7a6a400110 100644 --- a/flash/graph/classification/model.py +++ b/flash/graph/classification/model.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union import torch from torch import nn from torch.nn import functional as F from torch.nn import Linear -from torch.optim.lr_scheduler import _LRScheduler from flash.core.classification import ClassificationTask from flash.core.utilities.imports import _GRAPH_AVAILABLE @@ -92,9 +91,7 @@ class GraphClassifier(ClassificationTask): hidden_channels: Hidden dimension sizes. loss_fn: Loss function for training, defaults to cross entropy. optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. metrics: Metrics to compute for training and evaluation. learning_rate: Learning rate to use for training, defaults to `1e-3` model: GraphNN used, defaults to BaseGraphModel. 
@@ -109,10 +106,8 @@ def __init__( num_classes: int, hidden_channels: Union[List[int], int] = 512, loss_fn: Callable = F.cross_entropy, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-3, model: torch.nn.Module = None, @@ -132,9 +127,7 @@ def __init__( model=model, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, ) diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index d1a1147aff..9bf9b97c5a 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import FunctionType -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union -import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn -from torch.optim.lr_scheduler import _LRScheduler from torchmetrics import Metric from flash.core.classification import ClassificationAdapterTask, Labels @@ -56,9 +54,7 @@ def fn_resnet(pretrained: bool = True): which loads the default supervised pretrained weights. loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. 
In all cases, each metric needs to have the signature @@ -81,10 +77,8 @@ def __init__( head: Optional[Union[FunctionType, nn.Module]] = None, pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-3, multi_label: bool = False, @@ -138,9 +132,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, multi_label=multi_label, serializer=serializer or Labels(multi_label=multi_label), ) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 5922ef586b..cf53ac223f 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -11,10 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union - -import torch -from torch.optim.lr_scheduler import _LRScheduler +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer @@ -41,9 +38,7 @@ class ObjectDetector(AdapterTask): metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. Changing this argument currently has no effect. optimizer: The optimizer to use for training. Can either be the actual class or the class name. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. Has no effect for custom models. 
learning_rate: The learning rate to use for training @@ -60,10 +55,8 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "retinanet", pretrained: bool = True, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, learning_rate: float = 5e-3, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, @@ -84,9 +77,7 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, serializer=serializer or Preds(), ) diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index cff9949608..78a1a63074 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -from typing import Any, Callable, Dict, List, Optional, Type, Union - -import torch -from torch.optim.lr_scheduler import _LRScheduler +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from flash.core.adapter import AdapterTask from flash.core.data.data_source import DefaultDataKeys @@ -55,9 +52,7 @@ class ImageEmbedder(AdapterTask): backbone: VISSL backbone, defaults to ``resnet``. pretrained: Use a pretrained backbone, defaults to ``False``. optimizer: Optimizer to use for training and finetuning, defaults to :class:`torch.optim.SGD`. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. learning_rate: Learning rate to use for training, defaults to ``1e-3``. backbone_kwargs: arguments to be passed to VISSL backbones, i.e. ``vision_transformer`` and ``resnet``. training_strategy_kwargs: arguments passed to VISSL loss function, projection head and training hooks. 
@@ -77,10 +72,8 @@ def __init__( pretraining_transform: str, backbone: str = "resnet", pretrained: bool = False, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "SGD", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, learning_rate: float = 1e-3, backbone_kwargs: Optional[Dict[str, Any]] = None, training_strategy_kwargs: Optional[Dict[str, Any]] = None, @@ -110,9 +103,7 @@ def __init__( super().__init__( adapter=adapter, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, learning_rate=learning_rate, ) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 354611587f..bd5986c7de 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -11,10 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union - -import torch -from torch.optim.lr_scheduler import _LRScheduler +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer @@ -41,9 +38,7 @@ class InstanceSegmentation(AdapterTask): metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. Changing this argument currently has no effect. optimizer: The optimizer to use for training. Can either be the actual class or the class name. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. Has no effect for custom models. 
learning_rate: The learning rate to use for training @@ -60,10 +55,8 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "mask_rcnn", pretrained: bool = True, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, @@ -84,9 +77,7 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, serializer=serializer or Preds(), ) diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index c6b054f367..81ddb64ed6 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -11,10 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union - -import torch -from torch.optim.lr_scheduler import _LRScheduler +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer @@ -41,9 +38,7 @@ class KeypointDetector(AdapterTask): metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. Changing this argument currently has no effect. optimizer: The optimizer to use for training. Can either be the actual class or the class name. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. Has no effect for custom models. 
learning_rate: The learning rate to use for training @@ -61,10 +56,8 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "keypoint_rcnn", pretrained: bool = True, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, @@ -87,8 +80,7 @@ def __init__( learning_rate=learning_rate, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, serializer=serializer or Preds(), ) diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 7460ab15d0..c1e9c17f1a 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import nn from torch.nn import functional as F -from torch.optim.lr_scheduler import _LRScheduler from torchmetrics import IoU, Metric from flash.core.classification import ClassificationTask @@ -54,9 +53,7 @@ class SemanticSegmentation(ClassificationTask): pretrained: Use a pretrained backbone. loss_fn: Loss function for training. optimizer: Optimizer to use for training. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. 
In all cases, each metric needs to have the signature @@ -83,10 +80,8 @@ def __init__( head_kwargs: Optional[Dict] = None, pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-3, multi_label: bool = False, @@ -108,8 +103,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, serializer=serializer or SegmentationLabels(), diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 48aece6cd1..de39ecd019 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, cast, Dict, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, cast, Dict, List, Mapping, NoReturn, Optional, Sequence, Tuple, Union import torch from torch import nn -from torch.optim.lr_scheduler import _LRScheduler from flash.core.data.data_source import DefaultDataKeys from flash.core.data.process import Serializer @@ -59,9 +58,7 @@ class StyleTransfer(Task): content_weight: The weight associated with the content loss. A lower value will lose content over style. style_layers: Layers from the backbone to derive the style loss from. optimizer: Optimizer to use for training the model. - optimizer_kwargs: Optimizer keywords arguments. - scheduler: Scheduler to use for training the model. - scheduler_kwargs: Scheduler keywords arguments. + lr_scheduler: Scheduler to use for training the model. learning_rate: Learning rate to use for training, defaults to ``1e-3``. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. 
""" @@ -79,10 +76,8 @@ def __init__( content_weight: float = 1e5, style_layers: Union[Sequence[str], str] = ["relu1_2", "relu2_2", "relu3_3", "relu4_3"], style_weight: float = 1e10, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, learning_rate: float = 1e-3, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, ): @@ -113,8 +108,7 @@ def __init__( loss_fn=perceptual_loss, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, learning_rate=learning_rate, serializer=serializer, ) diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index 2d60d116e6..7224499886 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union import torch import torchmetrics from torch import nn -from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader, Sampler from flash.core.data.auto_dataset import BaseAutoDataset @@ -49,9 +48,7 @@ class PointCloudObjectDetector(Task): loss_fn: The loss function to use. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. optimizer: The optimizer or optimizer class to use. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. 
@@ -72,10 +69,8 @@ def __init__( backbone_kwargs: Optional[Dict] = None, head: Optional[nn.Module] = None, loss_fn: Optional[Callable] = None, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudObjectDetectorSerializer(), @@ -89,8 +84,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, serializer=serializer, diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index 266c6766e1..c2c1e64506 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch import torchmetrics @@ -19,7 +19,6 @@ from torch import nn from torch.nn import functional as F from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader, Sampler from torchmetrics import IoU @@ -80,9 +79,9 @@ class PointCloudSegmentation(ClassificationTask): loss_fn: The loss function to use. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. optimizer: The optimizer or optimizer class to use. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + + lr_scheduler: The scheduler or scheduler class to use. + metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. 
@@ -101,10 +100,8 @@ def __init__( backbone_kwargs: Optional[Dict] = None, head: Optional[nn.Module] = None, loss_fn: Optional[Callable] = torch.nn.functional.cross_entropy, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, multi_label: bool = False, @@ -120,8 +117,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 9bbb86560c..cb3ead134d 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch.nn import functional as F -from torch.optim.lr_scheduler import _LRScheduler from torchmetrics import Metric from flash.core.classification import ClassificationTask, Probabilities @@ -56,10 +55,8 @@ def __init__( num_classes: int, embedding_sizes: List[Tuple[int, int]] = None, loss_fn: Callable = F.cross_entropy, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, multi_label: bool = False, @@ -83,8 +80,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index 5199651ee3..6d47545db4 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union import torch import torchmetrics from torch import nn -from torch.optim.lr_scheduler import _LRScheduler from flash.core.classification import ClassificationTask, Labels from flash.core.data.data_source import DefaultDataKeys @@ -37,9 +36,9 @@ class TemplateSKLearnClassifier(ClassificationTask): loss_fn: The loss function to use. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. optimizer: The optimizer or optimizer class to use. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + + lr_scheduler: The scheduler or scheduler class to use. + metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. @@ -56,10 +55,8 @@ def __init__( backbone: Union[str, Tuple[nn.Module, int]] = "mlp-128", backbone_kwargs: Optional[Dict] = None, loss_fn: Optional[Callable] = None, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, multi_label: bool = False, @@ -70,8 +67,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 04b8ee72ce..2ddcce0a0b 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -13,11 +13,10 @@ # limitations under the License. import os import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch from pytorch_lightning import Callback -from torch.optim.lr_scheduler import _LRScheduler from torchmetrics import Metric from flash.core.classification import ClassificationTask, Labels @@ -40,9 +39,9 @@ class TextClassifier(ClassificationTask): num_classes: Number of classes to classify. backbone: A model to use to compute text features can be any BERT model from HuggingFace/transformersimage . optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + + lr_scheduler: The scheduler or scheduler class to use. 
+ metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature @@ -62,10 +61,8 @@ def __init__( num_classes: int, backbone: str = "prajjwal1/bert-medium", loss_fn: Optional[Callable] = None, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, multi_label: bool = False, @@ -86,8 +83,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 399b3f42d1..deab94793c 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -19,14 +19,13 @@ import collections import os import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from pytorch_lightning import Callback from pytorch_lightning.utilities import rank_zero_info from torch import Tensor -from torch.optim.lr_scheduler import _LRScheduler from torchmetrics import Metric from flash.core.data.data_source import DefaultDataKeys @@ -69,9 +68,9 @@ class QuestionAnsweringTask(Task): backbone: backbone model to use for the task. loss_fn: Loss function for training. optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + + lr_scheduler: The scheduler or scheduler class to use. + metrics: Metrics to compute for training and evaluation. Defauls to calculating the ROUGE metric. Changing this argument currently has no effect. 
learning_rate: Learning rate to use for training, defaults to `3e-4` @@ -95,10 +94,8 @@ def __init__( self, backbone: str = "distilbert-base-uncased", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 5e-5, enable_ort: bool = False, @@ -118,8 +115,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, ) diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index ae294a981f..f71c44bfe6 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -13,13 +13,12 @@ # limitations under the License. import os import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch from pytorch_lightning import Callback from pytorch_lightning.utilities import rank_zero_info from torch import Tensor -from torch.optim.lr_scheduler import _LRScheduler from torchmetrics import Metric from flash.core.finetuning import FlashBaseFinetuning @@ -63,9 +62,9 @@ class Seq2SeqTask(Task): Args: loss_fn: Loss function for training optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + + lr_scheduler: The scheduler or scheduler class to use. + metrics: Metrics to compute for training and evaluation. Changing this argument currently has no effect learning_rate: Learning rate to use for training, defaults to `3e-4` val_target_max_length: Maximum length of targets in validation. 
Defaults to `128` @@ -81,10 +80,8 @@ def __init__( self, backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 5e-5, val_target_max_length: Optional[int] = None, @@ -100,8 +97,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, ) diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index c21b948047..3ba6cf8703 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch -from torch.optim.lr_scheduler import _LRScheduler from torchmetrics import Metric from flash.text.seq2seq.core.metrics import RougeMetric @@ -36,9 +35,9 @@ class SummarizationTask(Seq2SeqTask): backbone: backbone model to use for the task. loss_fn: Loss function for training. optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + + lr_scheduler: The scheduler or scheduler class to use. + metrics: Metrics to compute for training and evaluation. Defauls to calculating the ROUGE metric. Changing this argument currently has no effect. 
learning_rate: Learning rate to use for training, defaults to `3e-4` @@ -53,10 +52,8 @@ def __init__( self, backbone: str = "sshleifer/distilbart-xsum-1-1", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-5, val_target_max_length: Optional[int] = None, @@ -71,8 +68,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, val_target_max_length=val_target_max_length, diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index 9894612f87..f01efaa0ab 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union -import torch -from torch.optim.lr_scheduler import _LRScheduler from torchmetrics import Metric from flash.text.seq2seq.core.metrics import BLEUScore @@ -36,9 +34,7 @@ class TranslationTask(Seq2SeqTask): backbone: backbone model to use for the task. loss_fn: Loss function for training. optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + lr_scheduler: The scheduler or scheduler class to use. metrics: Metrics to compute for training and evaluation. Defauls to calculating the BLEU metric. Changing this argument currently has no effect. 
learning_rate: Learning rate to use for training, defaults to `1e-5` @@ -53,10 +49,8 @@ def __init__( self, backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-5, val_target_max_length: Optional[int] = 128, @@ -71,8 +65,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, val_target_max_length=val_target_max_length, diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 482dd22fa3..c089708468 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import FunctionType -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch from pytorch_lightning import LightningModule @@ -22,7 +22,6 @@ from torch import nn from torch.nn import functional as F from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DistributedSampler from torchmetrics import Accuracy, Metric @@ -82,9 +81,9 @@ class VideoClassifier(ClassificationTask): pretrained: Use a pretrained backbone, defaults to ``True``. loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. - optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). - scheduler: The scheduler or scheduler class to use. - scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + + lr_scheduler: The scheduler or scheduler class to use. + metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. 
In all cases, each metric needs to have the signature @@ -103,10 +102,8 @@ def __init__( backbone_kwargs: Optional[Dict] = None, pretrained: bool = True, loss_fn: Callable = F.cross_entropy, - optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "SGD", - # optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "SGD", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = Accuracy(), learning_rate: float = 1e-3, head: Optional[Union[FunctionType, nn.Module]] = None, @@ -117,8 +114,7 @@ def __init__( loss_fn=loss_fn, optimizer=optimizer, # optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, serializer=serializer or Labels(), From 7a3029b4cbf1386141ac3004fa88ecdd977d6a36 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 29 Sep 2021 18:41:43 +0530 Subject: [PATCH 09/22] Updated optimizer registration code to compare with optimizer types and not optimizer names. --- flash/core/optimizers/optimizers.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/flash/core/optimizers/optimizers.py b/flash/core/optimizers/optimizers.py index c54f99905a..9ea305b7f4 100644 --- a/flash/core/optimizers/optimizers.py +++ b/flash/core/optimizers/optimizers.py @@ -8,21 +8,28 @@ _OPTIMIZERS_REGISTRY = FlashRegistry("optimizer") -_optimizers: List[Callable] = [getattr(optim, n) for n in dir(optim) if ("_" not in n)] +_optimizers: List[Callable] = [] +for n in dir(optim): + _optimizer = getattr(optim, n) + + if isclass(_optimizer) and _optimizer != optim.Optimizer and issubclass(_optimizer, optim.Optimizer): + _optimizers.append(_optimizer) for fn in _optimizers: _OPTIMIZERS_REGISTRY(fn, name=fn.__name__.lower()) + if _TORCH_OPTIMIZER_AVAILABLE: import torch_optimizer - _optimizers: List[Callable] = [ - getattr(torch_optimizer, n) - for n in dir(torch_optimizer) - if ("_" not in n) and isclass(getattr(torch_optimizer, n)) - ] + _torch_optimizers: List[Callable] = [] + for n in dir(torch_optimizer): + _optimizer = getattr(torch_optimizer, n) + + if isclass(_optimizer) and issubclass(_optimizer, optim.Optimizer): + _torch_optimizers.append(_optimizer) - for fn in _optimizers: + for fn in _torch_optimizers: name = fn.__name__.lower() if name not in _OPTIMIZERS_REGISTRY: _OPTIMIZERS_REGISTRY(fn, name=name) From d36c451b4644d4c4787687493b260c725b9aa014 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 29 Sep 2021 20:48:22 +0530 Subject: [PATCH 10/22] Added tests for Errors and Exceptions. 
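The hunks below tighten how `configure_optimizers` validates the `optimizer` argument and add tests for the failure modes. As a rough sketch of the three accepted forms (string key, callable, and `(key, kwargs)` tuple), here is a simplified illustration; the plain dict registry and the `resolve_optimizer` helper are stand-in names for this sketch, not Flash APIs:

```py
from typing import Any, Callable, Dict, Tuple, Union

import torch
from torch.optim import Optimizer

# Stand-in registry; Flash keeps these in a FlashRegistry keyed by lowercase name.
_OPTIMIZERS: Dict[str, Callable[..., Optimizer]] = {
    "adam": torch.optim.Adam,
    "adadelta": torch.optim.Adadelta,
}


def resolve_optimizer(
    spec: Union[str, Callable, Tuple[str, Dict[str, Any]]], parameters, lr: float
) -> Optimizer:
    """Instantiate an optimizer from a registry key, a callable, or a (key, kwargs) tuple."""
    if isinstance(spec, str):
        if spec.lower() not in _OPTIMIZERS:
            raise KeyError(f"{spec!r} is not a registered optimizer.")
        return _OPTIMIZERS[spec.lower()](parameters, lr=lr)
    if isinstance(spec, tuple):
        key, kwargs = spec
        if not isinstance(key, str) or not isinstance(kwargs, dict):
            raise TypeError("The tuple form must be (str, dict).")
        return _OPTIMIZERS[key.lower()](parameters, lr=lr, **kwargs)
    if callable(spec):
        return spec(parameters, lr=lr)
    raise TypeError(f"Unsupported optimizer specification: {type(spec)}")


# The tuple form forwards its kwargs to the optimizer constructor.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = resolve_optimizer(("Adadelta", {"eps": 0.5}), params, lr=1e-3)
```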
--- flash/core/model.py | 34 +++++++++++---------- requirements/test.txt | 1 + tests/core/test_model.py | 65 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 16 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index f74394b4d2..446ca50dc4 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -329,12 +329,7 @@ def __init__( self.model = model self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn) self.optimizer = optimizer - if isinstance(self.optimizer, Tuple): - assert isinstance(self.optimizer[0], str) or isinstance(self.optimizer[0], Callable) - self.lr_scheduler = lr_scheduler - if isinstance(self.lr_scheduler, Tuple): - assert isinstance(self.lr_scheduler[0], str) or isinstance(self.lr_scheduler[0], Callable) self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) @@ -480,6 +475,11 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: if isinstance(self.optimizer, str): + if self.optimizer.lower() not in self.available_optimizers(): + raise KeyError( + f"""Please provide a valid optimizer name and make sure it is registerd with the Optimizer registry. + Use `{self.__class__.__name__}.available_optimizers`.""" + ) optimizer_fn = self.optimizers.get(self.optimizer.lower()) _optimizers_kwargs: Dict[str, Any] = {} elif isinstance(self.optimizer, Callable): @@ -490,17 +490,16 @@ def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_ optimizer_key: str = self.optimizer[0] if not isinstance(optimizer_key, str): - raise MisconfigurationException( - f"Please provide a key from the available optimizers. \ - Refer to {self.__class__.__name__}.available_optimizers" + raise TypeError( + f"PThe first value in scheduler argument tuple should be a string but got {type(optimizer_key)}." ) optimizer_fn = self.optimizers.get(optimizer_key.lower()) _optimizers_kwargs: Dict[str, Any] = self.optimizer[1] else: raise TypeError( - f"Optimizer should be of type string or callable or tuple(string, dictionary) \ - but got {type(self.optimizer)}." + f"""Optimizer should be of type string or callable or tuple(string, dictionary) + but got {type(self.optimizer)}.""" ) model_parameters = filter(lambda p: p.requires_grad, self.parameters()) @@ -852,10 +851,15 @@ def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float] def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: if isinstance(self.lr_scheduler, str) or isinstance(self.lr_scheduler, Callable): + if isinstance(self.lr_scheduler, str) and self.lr_scheduler.lower() not in self.available_lr_schedulers(): + raise KeyError( + f"""Please provide a valid key and make sure it is registerd with the Scheduler registry. + Use `{self.__class__.__name__}.available_schedulers`.""" + ) # Get values based in type. 
if isinstance(self.lr_scheduler, str): - _lr_scheduler = self.lr_schedulers.get(self.lr_scheduler, with_metadata=True) + _lr_scheduler = self.lr_schedulers.get(self.lr_scheduler.lower(), with_metadata=True) lr_scheduler_fn: Callable = _lr_scheduler["fn"] lr_scheduler_metadata: Dict[str, Any] = _lr_scheduler["metadata"] else: @@ -889,12 +893,12 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: if not isinstance(self.lr_scheduler[0], str): raise TypeError( - f"The first value in scheduler argument tuple should be either a string or a callable \ - but got {type(self.lr_scheduler[0])}." + f"""The first value in scheduler argument tuple should be a string but got + {type(self.lr_scheduler[0])}.""" ) # Separate the key and the kwargs. - lr_scheduler_key_or_fn: Union[str, Callable] = self.lr_scheduler[0] + lr_scheduler_key: str = self.lr_scheduler[0] lr_scheduler_kwargs_and_config: Dict[str, Any] = self.lr_scheduler[1] # Get the default scheduler config. @@ -910,7 +914,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: assert all(config_key not in lr_scheduler_kwargs.keys() for config_key in lr_scheduler_config.keys()) # Retreive the scheduler callable with metadata from the registry. - _lr_scheduler = self.lr_schedulers.get(lr_scheduler_key_or_fn.lower(), with_metadata=True) + _lr_scheduler = self.lr_schedulers.get(lr_scheduler_key.lower(), with_metadata=True) lr_scheduler_fn: Callable = _lr_scheduler["fn"] lr_scheduler_metadata: Dict[str, Any] = _lr_scheduler["metadata"] diff --git a/requirements/test.txt b/requirements/test.txt index 3fecfe24d9..8465c065a2 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -14,3 +14,4 @@ isort #mypy scikit-learn pytest_mock +torch_optimizer diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 968cdcd743..8b13d2dbc3 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -25,6 +25,7 @@ import torch import torchmetrics from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor from torch.nn import functional as F from torch.utils.data import DataLoader @@ -33,7 +34,7 @@ from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask from flash.core.data.process import DefaultPreprocess, Postprocess -from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TEXT_AVAILABLE, Image +from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TEXT_AVAILABLE, _TORCH_OPTIMIZER_AVAILABLE, Image from flash.image import ImageClassificationData, ImageClassifier from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING @@ -345,6 +346,22 @@ def test_optimizers_and_schedulers(tmpdir, optim, sched, interval): trainer.fit(task, train_dl) +@pytest.mark.skipif(not _TORCH_OPTIMIZER_AVAILABLE, reason="torch_optimizer isn't installed.") +@pytest.mark.parametrize("optim", ["Yogi"]) +def test_external_optimizers_torch_optimizer(tmpdir, optim): + + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) + task = ClassificationTask(model, optimizer=optim, lr_scheduler=None, loss_fn=F.nll_loss) + trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count()) + ds = DummyDataset() + trainer.fit(task, train_dataloader=DataLoader(ds)) + + from torch_optimizer import Yogi + + optimizer = task.configure_optimizers() + assert isinstance(optimizer, Yogi) + + @pytest.mark.skipif(not 
_TEXT_AVAILABLE, reason="text libraries aren't installed.") @pytest.mark.parametrize("optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5)]) @pytest.mark.parametrize( @@ -372,6 +389,52 @@ def test_external_schedulers_provider_hf_transformers(tmpdir, optim, sched): assert isinstance(scheduler[0]["scheduler"], torch.optim.lr_scheduler.LambdaLR) +def test_errors_and_exceptions_optimizers_and_schedulers(): + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) + + with pytest.raises(TypeError): + task = ClassificationTask(model, optimizer=[1, 2, 3, 4], lr_scheduler=None) + task.configure_optimizers() + + with pytest.raises(KeyError): + task = ClassificationTask(model, optimizer="not_a_valid_key", lr_scheduler=None) + task.configure_optimizers() + + with pytest.raises(TypeError): + task = ClassificationTask( + model, optimizer=(["not", "a", "valid", "type"], {"random_kwarg": 10}), lr_scheduler=None + ) + task.configure_optimizers() + + with pytest.raises(KeyError): + task = ClassificationTask(model, optimizer="Adam", lr_scheduler="not_a_valid_key") + task.configure_optimizers() + + @ClassificationTask.lr_schedulers + def i_will_create_a_misconfiguration_exception(optimizer): + return "Done. Created." + + with pytest.raises(MisconfigurationException): + task = ClassificationTask(model, optimizer="Adam", lr_scheduler="i_will_create_a_misconfiguration_exception") + task.configure_optimizers() + + with pytest.raises(MisconfigurationException): + task = ClassificationTask(model, optimizer="Adam", lr_scheduler=i_will_create_a_misconfiguration_exception) + task.configure_optimizers() + + with pytest.raises(TypeError): + task = ClassificationTask(model, optimizer="Adam", lr_scheduler=["not", "a", "valid", "type"]) + task.configure_optimizers() + + with pytest.raises(TypeError): + task = ClassificationTask( + model, optimizer="Adam", lr_scheduler=(["not", "a", "valid", "type"], {"random_kwarg": 10}) + ) + task.configure_optimizers() + + pass + + def test_classification_task_metrics(): train_dataset = FixedDataset([0, 1]) val_dataset = FixedDataset([1, 1]) From 061454b170e2e3c650bddfcd452f3b3ba839aa92 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Thu, 30 Sep 2021 14:01:37 +0530 Subject: [PATCH 11/22] Update README with examples on using the API. 
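The README snippets added below use `functools.partial`, `CyclicLR` and `ImageClassifier` without showing their imports; roughly the following is assumed for them to run (an illustration, not part of the patch itself):

```py
import functools

import torch
from torch.optim.lr_scheduler import CyclicLR

from flash.image import ImageClassifier

# e.g. the "Callable" variants shown in the README examples: the partials are only
# stored here and are instantiated later inside configure_optimizers.
model = ImageClassifier(
    backbone="resnet18",
    num_classes=2,
    optimizer=functools.partial(torch.optim.Adadelta, eps=0.5),
    lr_scheduler=functools.partial(CyclicLR, step_size_up=1500, mode="exp_range", gamma=0.5),
)
```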
--- README.md | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 9112d23e61..a7e99bc8d3 100644
--- a/README.md
+++ b/README.md
@@ -180,7 +180,35 @@ ImageClassifier.available_schedulers()
Once you've chosen, create the model:

```py
-model = ImageClassifier(backbone="resnet18", optimizer='yogi', scheduler="cosine_with_hard_restarts_schedule_with_warmup", num_classes=2)
+#### The optimizer of choice can be passed as a
+# - String value
+model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None)
+
+# - Callable
+model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=functools.partial(torch.optim.AdaDelta, eps=0.5), lr_scheduler=None)
+
+# - Tuple[string, dict]: (The dict takes in the optimizer kwargs)
+model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("AdaDelta", {"epa": 0.5}), lr_scheduler=None)
+
+#### The scheduler of choice can be passed as a
+# - String value
+model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="constant_schedule")
+
+# - Callable
+model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=functools.partial(CyclicLR, step_size_up=1500, mode='exp_range', gamma=0.5))
+
+# - Tuple[string, dict]: (The dict takes in the scheduler kwargs)
+model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=("StepLR", {"step_size": 10}))
+```
+
+You can also register your own custom scheduler recipes beforehand and use them as shown above:
+
+```py
+@ImageClassifier.lr_schedulers
+def my_steplr_recipe(optimizer):
+    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
+
+model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_steplr_recipe")
```

### Flash Transforms

From c611aa80b8d70f4832aaa5b98df43aa3243bc7c6 Mon Sep 17 00:00:00 2001
From: Karthik Rangasai
Date: Thu, 30 Sep 2021 18:58:06 +0530
Subject: [PATCH 12/22] Update skipif condition only to check for transformers library instead of all text libraries.
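Patch 14 further below additionally lets `lr_scheduler` be a three-element tuple whose last entry is a Lightning scheduler configuration, and lets a registered recipe return such a configuration dictionary instead of a bare scheduler. A sketch of both forms, with illustrative kwargs, config values and recipe name:

```py
import torch

from flash.image import ImageClassifier

# (registry key, scheduler kwargs, Lightning scheduler config); the config must not set the "scheduler" key.
model = ImageClassifier(
    backbone="resnet18",
    num_classes=2,
    optimizer="Adam",
    lr_scheduler=("StepLR", {"step_size": 10}, {"interval": "epoch", "frequency": 1}),
)


# A registered recipe may also return a scheduler configuration dictionary.
@ImageClassifier.lr_schedulers
def my_steplr_config_recipe(optimizer):
    return {
        "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),
        "interval": "epoch",
    }


model = ImageClassifier(
    backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_steplr_config_recipe"
)
```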
--- tests/core/test_model.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 8b13d2dbc3..b23a148c8b 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -34,7 +34,7 @@ from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask from flash.core.data.process import DefaultPreprocess, Postprocess -from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TEXT_AVAILABLE, _TORCH_OPTIMIZER_AVAILABLE, Image +from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, Image from flash.image import ImageClassificationData, ImageClassifier from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING @@ -362,7 +362,7 @@ def test_external_optimizers_torch_optimizer(tmpdir, optim): assert isinstance(optimizer, Yogi) -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TRANSFORMERS_AVAILABLE, reason="transformers library isn't installed.") @pytest.mark.parametrize("optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5)]) @pytest.mark.parametrize( "sched", @@ -373,11 +373,6 @@ def test_external_optimizers_torch_optimizer(tmpdir, optim): ], ) def test_external_schedulers_provider_hf_transformers(tmpdir, optim, sched): - - # with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."): - # task = ClassificationTask(model, optimizer=optim, scheduler="linear_schedule_with_warmup") - # optimizer, scheduler = task.configure_optimizers() - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) task = ClassificationTask(model, optimizer=optim, lr_scheduler=sched, loss_fn=F.nll_loss) trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count()) From e158802b8a3ddc03d5f4b23ac9414b7260c4fac3 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Fri, 1 Oct 2021 13:52:55 +0530 Subject: [PATCH 13/22] Update newly added Face Detection Task. --- flash/image/face_detection/model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index 5d7c6e0445..edea34baf6 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import pytorch_lightning as pl import torch from torch import nn -from torch.optim import Optimizer from flash.core.data.data_source import DefaultDataKeys from flash.core.data.process import Preprocess, Serializer @@ -57,6 +56,7 @@ class FaceDetector(Task): metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. Changing this argument currently has no effect. optimizer: The optimizer to use for training. Can either be the actual class or the class name. + lr_scheduler: The scheduler or scheduler class to use. 
learning_rate: The learning rate to use for training """ @@ -68,7 +68,8 @@ def __init__( pretrained: bool = True, loss=None, metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, - optimizer: Type[Optimizer] = torch.optim.AdamW, + optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", + lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, learning_rate: float = 1e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, preprocess: Optional[Preprocess] = None, @@ -87,6 +88,7 @@ def __init__( metrics=metrics or {"AP": ff.metric.AveragePrecision()}, # TODO: replace with torch metrics MAP learning_rate=learning_rate, optimizer=optimizer, + lr_scheduler=lr_scheduler, serializer=serializer or DetectionLabels(), preprocess=preprocess or FaceDetectionPreprocess(), ) From 20eacaf19f1db497609b26fcfeb13202c7aaa327 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 13 Oct 2021 18:15:21 +0530 Subject: [PATCH 14/22] Changes from code review, Add new input method to lr_scheduler parameter, update tests. --- README.md | 4 +- flash/audio/speech_recognition/model.py | 4 +- flash/core/model.py | 210 ++++++++++++--------- flash/graph/classification/model.py | 4 +- flash/image/classification/model.py | 4 +- flash/image/detection/model.py | 4 +- flash/image/embedding/model.py | 4 +- flash/image/face_detection/model.py | 4 +- flash/image/instance_segmentation/model.py | 4 +- flash/image/keypoint_detection/model.py | 4 +- flash/image/segmentation/model.py | 4 +- flash/image/style_transfer/model.py | 4 +- flash/pointcloud/detection/model.py | 4 +- flash/pointcloud/segmentation/model.py | 4 +- flash/tabular/classification/model.py | 4 +- flash/template/classification/model.py | 4 +- flash/text/classification/model.py | 4 +- flash/text/question_answering/model.py | 4 +- flash/text/seq2seq/core/model.py | 4 +- flash/text/seq2seq/summarization/model.py | 4 +- flash/text/seq2seq/translation/model.py | 4 +- flash/video/classification/model.py | 4 +- tests/core/test_model.py | 13 +- 23 files changed, 194 insertions(+), 113 deletions(-) diff --git a/README.md b/README.md index 1e8054670f..2d18bb8848 100644 --- a/README.md +++ b/README.md @@ -196,10 +196,10 @@ Once you've chosen, create the model: model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None) # - Callable -model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=functools.partial(torch.optim.AdaDelta, eps=0.5), lr_scheduler=None) +model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), lr_scheduler=None) # - Tuple[string, dict]: (The dict takes in the optimizer kwargs) -model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("AdaDelta", {"epa": 0.5}), lr_scheduler=None) +model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("Adadelta", {"epa": 0.5}), lr_scheduler=None) #### The scheduler of choice can be passed as a # - String value diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index 5e77a919bd..d259c4ae75 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -52,7 +52,9 @@ def __init__( self, backbone: str = "facebook/wav2vec2-base-960h", optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, 
Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, learning_rate: float = 1e-5, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, ): diff --git a/flash/core/model.py b/flash/core/model.py index 04229a63da..ed8728a26e 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -326,7 +326,9 @@ def __init__( loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, learning_rate: float = 5e-5, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None, preprocess: Optional[Preprocess] = None, @@ -496,29 +498,44 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A batch = torch.stack(batch) return self(batch) + def _get_optimizer_class_from_registry(self, optimizer_key: str) -> Optimizer: + if optimizer_key.lower() not in self.available_optimizers(): + raise KeyError( + f"Please provide a valid optimizer name and make sure it is registerd with the Optimizer registry." + f"\nUse `{self.__class__.__name__}.available_optimizers()` to list the available optimizers." + f"\nList of available Optimizers: {self.available_optimizers()}." + ) + optimizer_fn = self.optimizers.get(optimizer_key.lower()) + return optimizer_fn + def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: """Implement how optimizer and optionally learning rate schedulers should be configured.""" if isinstance(self.optimizer, str): - if self.optimizer.lower() not in self.available_optimizers(): - raise KeyError( - f"""Please provide a valid optimizer name and make sure it is registerd with the Optimizer registry. - Use `{self.__class__.__name__}.available_optimizers`.""" - ) - optimizer_fn = self.optimizers.get(self.optimizer.lower()) + optimizer_fn = self._get_optimizer_class_from_registry(self.optimizer.lower()) _optimizers_kwargs: Dict[str, Any] = {} elif isinstance(self.optimizer, Callable): optimizer_fn = self.optimizer _optimizers_kwargs: Dict[str, Any] = {} elif isinstance(self.optimizer, Tuple): - optimizer_fn: Callable = None - optimizer_key: str = self.optimizer[0] + if len(self.optimizer) != 2: + raise MisconfigurationException( + f"The tuple configuration of an optimizer input must be of length 2 with the first index" + f" containing a str from {self.available_optimizers()} and the second index containing the" + f" required keyword arguments to initialize the Optimizer." + ) - if not isinstance(optimizer_key, str): + if not isinstance(self.optimizer[0], str): raise TypeError( - f"PThe first value in scheduler argument tuple should be a string but got {type(optimizer_key)}." + f"The first value in optimizer argument tuple should be a string but got {type(self.optimizer[0])}." ) - optimizer_fn = self.optimizers.get(optimizer_key.lower()) + if not isinstance(self.optimizer[1], Dict): + raise TypeError( + f"The second value in optimizer argument tuple should be of dict type but got " + f"{type(self.optimizer[1])}." 
+ ) + + optimizer_fn: Callable = self._get_optimizer_class_from_registry(self.optimizer[0]) _optimizers_kwargs: Dict[str, Any] = self.optimizer[1] else: raise TypeError( @@ -876,92 +893,117 @@ def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float] num_warmup_steps *= num_training_steps return round(num_warmup_steps) + def _get_lr_scheduler_class_from_registry(self, lr_scheduler_key: str) -> Dict[str, Any]: + if lr_scheduler_key.lower() not in self.available_lr_schedulers(): + raise KeyError( + f"Please provide a valid scheduler name and make sure it is registerd with the Scheduler registry." + f"\nUse `{self.__class__.__name__}.available_lr_schedulers()` to list the available schedulers." + f"\n>>> List of available LR Schedulers: {self.available_lr_schedulers()}." + ) + lr_scheduler_fn: Dict[str, Any] = self.lr_schedulers.get(lr_scheduler_key.lower(), with_metadata=True) + return deepcopy(lr_scheduler_fn) + def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: - if isinstance(self.lr_scheduler, str) or isinstance(self.lr_scheduler, Callable): - if isinstance(self.lr_scheduler, str) and self.lr_scheduler.lower() not in self.available_lr_schedulers(): - raise KeyError( - f"""Please provide a valid key and make sure it is registerd with the Scheduler registry. - Use `{self.__class__.__name__}.available_schedulers`.""" - ) + print(type(self.lr_scheduler)) + if isinstance(self.lr_scheduler, str): + lr_scheduler_data: Dict[str, Any] = self._get_lr_scheduler_class_from_registry(self.lr_scheduler) + lr_scheduler_fn = lr_scheduler_data.pop("fn") + lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None) + lr_scheduler_kwargs: Dict[str, Any] = {} + lr_scheduler_config = _get_default_scheduler_config() + for key, value in lr_scheduler_config.items(): + lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value + + elif isinstance(self.lr_scheduler, Callable): + lr_scheduler_data = {} + lr_scheduler_fn = self.lr_scheduler + lr_scheduler_metadata: Dict[str, Any] = None + lr_scheduler_kwargs: Dict[str, Any] = {} + lr_scheduler_config = _get_default_scheduler_config() - # Get values based in type. - if isinstance(self.lr_scheduler, str): - _lr_scheduler = self.lr_schedulers.get(self.lr_scheduler.lower(), with_metadata=True) - lr_scheduler_fn: Callable = _lr_scheduler["fn"] - lr_scheduler_metadata: Dict[str, Any] = _lr_scheduler["metadata"] - else: - lr_scheduler_fn: Callable = self.lr_scheduler + elif isinstance(self.lr_scheduler, Tuple): + if len(self.lr_scheduler) not in [2, 3]: + raise MisconfigurationException( + f"The tuple configuration of an scheduler input must be:\n" + f"1) Of length 2 with the first index containing a str from {self.available_lr_schedulers()} and" + f" the second index containing the required keyword arguments to initialize the LR Scheduler.\n" + f"2) Of length 3 with the first index containing a str from {self.available_lr_schedulers()} and" + f" the second index containing the required keyword arguments to initialize the LR Scheduler and" + f" the third index containing a Lightning scheduler configuration dictionary of the format" + f" {_get_default_scheduler_config()}. NOTE: Do not set the `scheduler` key in the" + f" lr_scheduler_config, it will overriden with an instance of the provided scheduler key." + ) - # Generate the output: could be a lr_scheduler object or a lr_scheduler config. 
- sched_output: Union[_LRScheduler, Dict[str, Any]] = lr_scheduler_fn(optimizer) + if not isinstance(self.lr_scheduler[0], (str, Callable)): + raise TypeError( + f"The first value in lr_scheduler argument tuple should be of type string or type Callable" + f" but got {type(self.lr_scheduler[0])}." + ) - # Create and/or update a lr_scheduler configuration - lr_scheduler_config = _get_default_scheduler_config() - if isinstance(sched_output, _LRScheduler): - lr_scheduler_config["scheduler"] = sched_output - if isinstance(self.lr_scheduler, str) and "interval" in lr_scheduler_metadata.keys(): - lr_scheduler_config["interval"] = lr_scheduler_metadata["interval"] - elif isinstance(sched_output, dict): - for key, value in sched_output.items(): - lr_scheduler_config[key] = value - else: - if isinstance(self.lr_scheduler, str): - message = "register a custom callable" - else: - message = "provide a callable" - raise MisconfigurationException( - f"Please {message} that outputs either an LR Scheduler or a scheduler condifguration." + if not isinstance(self.lr_scheduler[1], Dict): + raise TypeError( + f"The second value in lr_scheduler argument tuple should be of type dict but got" + f" {type(self.lr_scheduler[1])}." ) - return lr_scheduler_config + if len(self.lr_scheduler) == 3 and not isinstance(self.lr_scheduler[2], Dict): + raise TypeError( + f"The third value in lr_scheduler argument tuple should be of type dict but got" + f" {type(self.lr_scheduler[2])}." + ) - if not isinstance(self.lr_scheduler, Tuple): - raise TypeError("The scheduler arguments should be provided as a tuple.") + lr_scheduler_data: Dict[str, Any] = self._get_lr_scheduler_class_from_registry(self.lr_scheduler[0]) + lr_scheduler_fn = lr_scheduler_data.pop("fn") + lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None) + lr_scheduler_kwargs: Dict[str, Any] = self.lr_scheduler[1] + lr_scheduler_config = _get_default_scheduler_config() + for key, value in lr_scheduler_config.items(): + lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value + if len(self.lr_scheduler) == 3: + lr_scheduler_config.update(self.lr_scheduler[2]) - if not isinstance(self.lr_scheduler[0], str): + else: raise TypeError( - f"""The first value in scheduler argument tuple should be a string but got - {type(self.lr_scheduler[0])}.""" + f"`lr_scheduler` argument should be of type string or callable or tuple(string, dictionary)" + f" or tuple(string, dictionary, dictionary) but got {type(self.lr_scheduler)}." ) - # Separate the key and the kwargs. - lr_scheduler_key: str = self.lr_scheduler[0] - lr_scheduler_kwargs_and_config: Dict[str, Any] = self.lr_scheduler[1] - - # Get the default scheduler config. - lr_scheduler_config: Dict[str, Any] = _get_default_scheduler_config() - lr_scheduler_config["interval"] = None - - # Update scheduler config from the kwargs and pop the keys from the kwargs at the same time. - for config_key, config_value in lr_scheduler_config.items(): - lr_scheduler_config[config_key] = lr_scheduler_kwargs_and_config.pop(config_key, None) or config_value - - # Create a new copy of the kwargs. - lr_scheduler_kwargs = deepcopy(lr_scheduler_kwargs_and_config) - assert all(config_key not in lr_scheduler_kwargs.keys() for config_key in lr_scheduler_config.keys()) - - # Retreive the scheduler callable with metadata from the registry. 
-        _lr_scheduler = self.lr_schedulers.get(lr_scheduler_key.lower(), with_metadata=True)
-        lr_scheduler_fn: Callable = _lr_scheduler["fn"]
-        lr_scheduler_metadata: Dict[str, Any] = _lr_scheduler["metadata"]
-
-        # Make necessary adjustment to the kwargs based on the provider of the scheduler.
-        if "providers" in lr_scheduler_metadata.keys():
+        # Apply provider-specific adjustments to the scheduler kwargs (e.g. warmup steps for Hugging Face schedulers).
+        if lr_scheduler_metadata is not None and "providers" in lr_scheduler_metadata.keys():
             if lr_scheduler_metadata["providers"] == providers._HUGGINGFACE:
-                num_training_steps: int = self.get_num_training_steps()
-                num_warmup_steps: int = self._compute_warmup(
-                    num_training_steps=num_training_steps,
-                    num_warmup_steps=lr_scheduler_kwargs["num_warmup_steps"],
-                )
-                lr_scheduler_kwargs["num_warmup_steps"] = num_warmup_steps
-                lr_scheduler_kwargs["num_training_steps"] = num_training_steps
+                if lr_scheduler_data["name"] != "constant_schedule":
+                    num_training_steps: int = self.get_num_training_steps()
+                    num_warmup_steps: int = self._compute_warmup(
+                        num_training_steps=num_training_steps,
+                        num_warmup_steps=lr_scheduler_kwargs["num_warmup_steps"],
+                    )
+                    lr_scheduler_kwargs["num_warmup_steps"] = num_warmup_steps
+                    if lr_scheduler_data["name"] != "constant_schedule_with_warmup":
+                        lr_scheduler_kwargs["num_training_steps"] = num_training_steps
+
+        # A user can register a callable that returns either an LR scheduler or a full lr_scheduler_config:
+        # 1) If the return value is an instance of _LRScheduler -> add it to the current config and return the config.
+        # 2) If the return value is a dictionary, check that it only contains valid lr_scheduler_config keys and return it.
+        lr_scheduler: Union[_LRScheduler, Dict[str, Any]] = lr_scheduler_fn(optimizer, **lr_scheduler_kwargs)
+
+        if not isinstance(lr_scheduler, (_LRScheduler, Dict)):
+            raise MisconfigurationException(
+                f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler"
+                f" configuration with keys belonging to {list(_get_default_scheduler_config().keys())}."
+            )
-        # Set the scheduler in the config.
-        lr_scheduler_config["scheduler"] = lr_scheduler_fn(optimizer, **lr_scheduler_kwargs)
+        if isinstance(lr_scheduler, Dict):
+            dummy_config = _get_default_scheduler_config()
+            if not all(config_key in dummy_config.keys() for config_key in lr_scheduler.keys()):
+                raise MisconfigurationException(
+                    f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler"
+                    f" configuration with keys belonging to {list(dummy_config.keys())}."
+                )
+            # If all provided keys are valid, return the config as-is.
+            return lr_scheduler
-        # Update the interval in sched config just in case it has NoneType.
-        if "interval" in lr_scheduler_metadata.keys():
-            lr_scheduler_config["interval"] = lr_scheduler_config["interval"] or lr_scheduler_metadata["interval"]
+        # If `lr_scheduler` is not a Dict, then add it to the current config and return the config.
+ lr_scheduler_config["scheduler"] = lr_scheduler return lr_scheduler_config def _load_from_state_dict( diff --git a/flash/graph/classification/model.py b/flash/graph/classification/model.py index 7a6a400110..c1eb736a18 100644 --- a/flash/graph/classification/model.py +++ b/flash/graph/classification/model.py @@ -107,7 +107,9 @@ def __init__( hidden_channels: Union[List[int], int] = 512, loss_fn: Callable = F.cross_entropy, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-3, model: torch.nn.Module = None, diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index 60c6ec3577..19ec1f8295 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -82,7 +82,9 @@ def __init__( pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-3, multi_label: bool = False, diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index efbbda2ec8..57440eb8fe 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -58,7 +58,9 @@ def __init__( head: Optional[str] = "retinanet", pretrained: bool = True, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, learning_rate: float = 5e-3, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 78a1a63074..e205eb5054 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -73,7 +73,9 @@ def __init__( backbone: str = "resnet", pretrained: bool = False, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, learning_rate: float = 1e-3, backbone_kwargs: Optional[Dict[str, Any]] = None, training_strategy_kwargs: Optional[Dict[str, Any]] = None, diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index edea34baf6..042d786dc7 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -69,7 +69,9 @@ def __init__( loss=None, metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, learning_rate: 
float = 1e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, preprocess: Optional[Preprocess] = None, diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index bd5986c7de..4ddf3ec942 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -56,7 +56,9 @@ def __init__( head: Optional[str] = "mask_rcnn", pretrained: bool = True, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 81ddb64ed6..8e2f3fdd9a 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -57,7 +57,9 @@ def __init__( head: Optional[str] = "keypoint_rcnn", pretrained: bool = True, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 7c6b6e322a..cb9f69a3b4 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -82,7 +82,9 @@ def __init__( pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-3, multi_label: bool = False, diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 1759f25015..789d539ee2 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -78,7 +78,9 @@ def __init__( style_layers: Union[Sequence[str], str] = ["relu1_2", "relu2_2", "relu3_3", "relu4_3"], style_weight: float = 1e10, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, learning_rate: float = 1e-3, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, ): diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index 7224499886..20eed3c81b 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -70,7 +70,9 @@ def __init__( head: Optional[nn.Module] = None, loss_fn: Optional[Callable] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, 
Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudObjectDetectorSerializer(), diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index c2c1e64506..a0fc8fe816 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -101,7 +101,9 @@ def __init__( head: Optional[nn.Module] = None, loss_fn: Optional[Callable] = torch.nn.functional.cross_entropy, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, multi_label: bool = False, diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index cb3ead134d..6578625a99 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -56,7 +56,9 @@ def __init__( embedding_sizes: List[Tuple[int, int]] = None, loss_fn: Callable = F.cross_entropy, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, multi_label: bool = False, diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index 6d47545db4..ff25f01467 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -56,7 +56,9 @@ def __init__( backbone_kwargs: Optional[Dict] = None, loss_fn: Optional[Callable] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, multi_label: bool = False, diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 2ddcce0a0b..b2d8a9cb72 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -62,7 +62,9 @@ def __init__( backbone: str = "prajjwal1/bert-medium", loss_fn: Optional[Callable] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, multi_label: bool = False, diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index d93c446241..264b938ee6 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -95,7 +95,9 @@ def __init__( backbone: str = "distilbert-base-uncased", loss_fn: 
Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 5e-5, enable_ort: bool = False, diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index f71c44bfe6..e124550de0 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -81,7 +81,9 @@ def __init__( backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 5e-5, val_target_max_length: Optional[int] = None, diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index 3ba6cf8703..55e5c58385 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -53,7 +53,9 @@ def __init__( backbone: str = "sshleifer/distilbart-xsum-1-1", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-5, val_target_max_length: Optional[int] = None, diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index f01efaa0ab..32845d09da 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -50,7 +50,9 @@ def __init__( backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-5, val_target_max_length: Optional[int] = 128, diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 6e980f9360..e96810d611 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -108,7 +108,9 @@ def __init__( pretrained: bool = True, loss_fn: Callable = F.cross_entropy, optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "SGD", - lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None, + lr_scheduler: Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] + ] = None, metrics: Union[Metric, Callable, Mapping, Sequence, None] = Accuracy(), learning_rate: float = 1e-3, head: Optional[Union[FunctionType, nn.Module]] = None, diff --git a/tests/core/test_model.py 
b/tests/core/test_model.py index c904e1fb06..d1f2b9c478 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -13,6 +13,7 @@ # limitations under the License. import functools import math +from copy import deepcopy from itertools import chain from numbers import Number from pathlib import Path @@ -23,7 +24,6 @@ import pytest import pytorch_lightning as pl import torch -import torchmetrics from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor @@ -318,7 +318,7 @@ def custom_steplr_configuration_return_as_dict(optimizer): ("custom_steplr_configuration_return_as_dict", "step"), (functools.partial(torch.optim.lr_scheduler.StepLR, step_size=10), "epoch"), (("StepLR", {"step_size": 10}), "step"), - (("StepLR", {"step_size": 10, "interval": None}), "step"), + (("StepLR", {"step_size": 10}, {"interval": "epoch"}), "epoch"), ], ) def test_optimizers_and_schedulers(tmpdir, optim, sched, interval): @@ -375,14 +375,13 @@ def test_external_optimizers_torch_optimizer(tmpdir, optim): ) def test_external_schedulers_provider_hf_transformers(tmpdir, optim, sched): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) - task = ClassificationTask(model, optimizer=optim, lr_scheduler=sched, loss_fn=F.nll_loss) - trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count()) + task = ClassificationTask(model, optimizer=deepcopy(optim), lr_scheduler=deepcopy(sched), loss_fn=F.nll_loss) + trainer = flash.Trainer(max_epochs=1, limit_train_batches=10, gpus=torch.cuda.device_count()) ds = DummyDataset() trainer.fit(task, train_dataloader=DataLoader(ds)) - optimizer, scheduler = task.configure_optimizers() - assert isinstance(optimizer[0], torch.optim.Adadelta) - assert isinstance(scheduler[0]["scheduler"], torch.optim.lr_scheduler.LambdaLR) + assert isinstance(trainer.optimizers[0], torch.optim.Adadelta) + assert isinstance(trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.LambdaLR) def test_errors_and_exceptions_optimizers_and_schedulers(): From ddb5d1f343e471db63217461c9cb38cefc535569 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 13 Oct 2021 18:21:04 +0530 Subject: [PATCH 15/22] Fix pre-commit ci review. --- flash/image/instance_segmentation/model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 97be34a77f..4ea6888d14 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -13,10 +13,7 @@ # limitations under the License. from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union -import torch from pytorch_lightning.utilities import rank_zero_info -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler from flash.core.adapter import AdapterTask from flash.core.data.data_pipeline import DataPipeline From eb3aaec9c2c8509d9b50d32d3a9a23e27affb22f Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Thu, 14 Oct 2021 14:35:53 +0530 Subject: [PATCH 16/22] Add documentation for using the modified API and update CHANGELOG. 
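As a quick illustration of the reworked `_instantiate_lr_scheduler` above: a registered LR scheduler recipe may return either an `_LRScheduler` instance or a full Lightning scheduler configuration dictionary. A minimal sketch of the dict-returning path (the recipe name `my_steplr_config_recipe` and the toy model are illustrative only, not part of this change):

    import torch
    from torch import nn
    from torch.nn import functional as F

    from flash.core.classification import ClassificationTask


    # Hypothetical recipe: returns a Lightning scheduler config instead of a bare scheduler.
    @ClassificationTask.lr_schedulers
    def my_steplr_config_recipe(optimizer):
        return {
            "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),
            "interval": "step",
        }


    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
    task = ClassificationTask(model, optimizer="Adam", lr_scheduler="my_steplr_config_recipe", loss_fn=F.nll_loss)

The returned dictionary may only use the keys of the default Lightning scheduler config; anything else raises a `MisconfigurationException`.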
---
 CHANGELOG.md                         |   2 +
 docs/source/general/optimization.rst | 197 +++++++++++++++++++++++++++
 docs/source/general/registry.rst     |   5 +-
 docs/source/index.rst                |   1 +
 flash/core/optimizers/schedulers.py  |   2 +
 5 files changed, 205 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/general/optimization.rst

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8038acec31..fcecd397b2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
 
+- Optimizer and LR Scheduler registry are used to get the respective inputs to the Task using a string (or a callable). ([777](https://github.com/PyTorchLightning/lightning-flash/pull/777))
+
 ### Fixed
 
 - Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored ([#792](https://github.com/PyTorchLightning/lightning-flash/pull/792))
diff --git a/docs/source/general/optimization.rst b/docs/source/general/optimization.rst
new file mode 100644
index 0000000000..fd65244fa4
--- /dev/null
+++ b/docs/source/general/optimization.rst
@@ -0,0 +1,197 @@
+
+.. _optimization:
+
+########################################
+Optimization (Optimizers and Schedulers)
+########################################
+
+Using optimizers and learning rate schedulers with Flash has become easier and cleaner than ever.
+
+With the use of :ref:`registry`, instantiation of an optimizer or a learning rate scheduler can be done with just a string.
+
+Setting an optimizer to a task
+==============================
+
+Each task has an inbuilt method :func:`~flash.core.model.Task.available_optimizers` which will list all the optimizers
+registered with Flash.
+
+    >>> from flash.core.classification import ClassificationTask
+    >>> ClassificationTask.available_optimizers()
+    ['adadelta', ..., 'sgd']
+
+To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass in a string.
+
+.. code-block:: python
+
+    from flash.image import ImageClassifier
+
+    model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4)
+
+
+In order to customize specific parameters of the Optimizer, pass along a dictionary of kwargs with the string as a tuple.
+
+.. code-block:: python
+
+    from flash.image import ImageClassifier
+
+    model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4)
+
+
+An alternative to customizing an optimizer using a tuple is to pass it as a callable.
+
+.. code-block:: python
+
+    from functools import partial
+    from torch.optim import Adam
+    from flash.image import ImageClassifier
+
+    model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4)
+
+
+Setting a Learning Rate Scheduler
+=================================
+
+Each task has an inbuilt method :func:`~flash.core.model.Task.available_lr_schedulers` which will list all the learning
+rate schedulers registered with Flash.
+
+    >>> from flash.core.classification import ClassificationTask
+    >>> ClassificationTask.available_lr_schedulers()
+    ['lambdalr', ..., 'cosineannealingwarmrestarts']
+
+To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass in a string.
+
+.. 
code-block:: python
+
+    from flash.image import ImageClassifier
+
+    model = ImageClassifier(
+        num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule"
+    )
+
+.. note:: ``"constant_schedule"`` and a few other lr schedulers will be available only if you have installed the ``transformers`` library from Hugging Face.
+
+
+In order to customize specific parameters of the LR Scheduler, pass along a dictionary of kwargs with the string as a tuple.
+
+.. code-block:: python
+
+    from flash.image import ImageClassifier
+
+    model = ImageClassifier(
+        num_classes=10,
+        backbone="resnet18",
+        optimizer="Adam",
+        learning_rate=1e-4,
+        lr_scheduler=("StepLR", {"step_size": 10}),
+    )
+
+
+An alternative to customizing the LR Scheduler using a tuple is to pass it as a callable.
+
+.. code-block:: python
+
+    from functools import partial
+    from torch.optim.lr_scheduler import CyclicLR
+    from flash.image import ImageClassifier
+
+    model = ImageClassifier(
+        num_classes=10,
+        backbone="resnet18",
+        optimizer="Adam",
+        learning_rate=1e-4,
+        lr_scheduler=partial(CyclicLR, step_size_up=1500, mode="exp_range", gamma=0.5),
+    )
+
+
+Additionally, the ``lr_scheduler`` parameter also accepts the Lightning Scheduler configuration which can be passed in using a tuple.
+
+The Lightning Scheduler configuration is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
+
+.. code-block:: python
+
+    lr_scheduler_config = {
+        # REQUIRED: The scheduler instance
+        "scheduler": lr_scheduler,
+        # The unit of the scheduler's step size, could also be 'step'.
+        # 'epoch' updates the scheduler on epoch end whereas 'step'
+        # updates it after an optimizer update.
+        "interval": "epoch",
+        # How many epochs/steps should pass between calls to
+        # `scheduler.step()`. 1 corresponds to updating the learning
+        # rate after every epoch/step.
+        "frequency": 1,
+        # Metric to monitor for schedulers like `ReduceLROnPlateau`
+        "monitor": "val_loss",
+        # If set to `True`, will enforce that the value specified 'monitor'
+        # is available when the scheduler is updated, thus stopping
+        # training if not found. If set to `False`, it will only produce a warning
+        "strict": True,
+        # If using the `LearningRateMonitor` callback to monitor the
+        # learning rate progress, this keyword can be used to specify
+        # a custom logged name
+        "name": None,
+    }
+
+When there are schedulers in which the ``.step()`` method is conditioned on a value, such as the ``torch.optim.lr_scheduler.ReduceLROnPlateau`` scheduler,
+Flash requires that the Lightning Scheduler configuration contains the keyword ``"monitor"`` set to the metric name that the scheduler should be conditioned on.
+Below is an example of this:
+
+.. code-block:: python
+
+    from flash.image import ImageClassifier
+
+    model = ImageClassifier(
+        num_classes=10,
+        backbone="resnet18",
+        optimizer="Adam",
+        learning_rate=1e-4,
+        lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}),
+    )
+
+
+.. note:: Do not set the ``"scheduler"`` key in the Lightning Scheduler configuration; it will be overridden with an instance of the provided scheduler key.
+
+
+Pre-Registering optimizers and scheduler recipes
+================================================
+
+The Flash registry also provides the flexibility of registering functions. This feature is also available in the Optimizer and Scheduler registries.
+ +Using the ``optimizers`` and ``lr_schedulers`` decorator pertaining to each :class:`~flash.core.model.Task`, custom optimizer and LR scheduler recipes can be pre-registered. + +.. code-block:: python + + import torch + from flash.image import ImageClassifier + + + @ImageClassifier.lr_schedulers + def my_flash_steplr_recipe(optimizer): + return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10) + + + model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe") + + +Provider specific requirements +============================== + +Schedulers +********** + +Certain LR Schedulers provided by Hugging Face require both ``num_training_steps`` and ``num_warmup_steps``. + +In order to use them in Flash, just provide ``num_warmup_steps`` as float between 0 and 1 which indicates the fraction of the training steps +that will be used as warmup steps. Flash's :class:`~flash.core.trainer.Trainer` will take care of computing the number of training steps and +number of warmup steps based on the flags that are set in the Trainer. + +.. code-block:: python + + from flash.image import ImageClassifier + + model = ImageClassifier( + backbone="resnet18", + num_classes=2, + optimizer="Adam", + lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}), + ) diff --git a/docs/source/general/registry.rst b/docs/source/general/registry.rst index c3d7a96806..0cf78aa552 100644 --- a/docs/source/general/registry.rst +++ b/docs/source/general/registry.rst @@ -1,9 +1,10 @@ + +.. _registry: + ######## Registry ######## -.. _registry: - ******************** Available Registries ******************** diff --git a/docs/source/index.rst b/docs/source/index.rst index 9fcace7dfe..120a3faada 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -33,6 +33,7 @@ Lightning Flash general/registry general/serve general/backbones + general/optimization .. toctree:: :maxdepth: 1 diff --git a/flash/core/optimizers/schedulers.py b/flash/core/optimizers/schedulers.py index 2e36c41ec4..b385dafff8 100644 --- a/flash/core/optimizers/schedulers.py +++ b/flash/core/optimizers/schedulers.py @@ -8,6 +8,7 @@ CosineAnnealingWarmRestarts, CyclicLR, MultiStepLR, + ReduceLROnPlateau, StepLR, ) @@ -24,6 +25,7 @@ if inspect.isclass(sched) and sched != _LRScheduler and issubclass(sched, _LRScheduler): schedulers.append(sched) +schedulers.append(ReduceLROnPlateau) for scheduler in schedulers: From 50c936a0eb85f47a8ea64f72c88730c9f407b286 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Thu, 14 Oct 2021 14:48:56 +0530 Subject: [PATCH 17/22] Update docstrings for all tasks. 
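Relatedly, for the Hugging Face schedulers documented above, a ``num_warmup_steps`` given as a float is converted into an absolute number of steps before the scheduler is built. A simplified, standalone sketch of that arithmetic, mirroring the multiply-and-round logic in ``Task._compute_warmup``:

    def compute_warmup(num_training_steps: int, num_warmup_steps: float) -> int:
        # A float in (0, 1) is interpreted as a fraction of the total number of training steps.
        if isinstance(num_warmup_steps, float):
            num_warmup_steps *= num_training_steps
        return round(num_warmup_steps)


    # e.g. a 0.1 warmup fraction over 1000 training steps gives 100 warmup steps
    assert compute_warmup(num_training_steps=1000, num_warmup_steps=0.1) == 100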
--- flash/audio/speech_recognition/model.py | 4 ++-- flash/core/model.py | 6 +++--- flash/graph/classification/model.py | 7 +++---- flash/image/classification/model.py | 4 ++-- flash/image/detection/model.py | 4 ++-- flash/image/embedding/model.py | 4 ++-- flash/image/face_detection/model.py | 4 ++-- flash/image/instance_segmentation/model.py | 4 ++-- flash/image/keypoint_detection/model.py | 4 ++-- flash/image/segmentation/model.py | 2 +- flash/image/style_transfer/model.py | 4 ++-- flash/pointcloud/detection/model.py | 4 ++-- flash/pointcloud/segmentation/model.py | 6 ++---- flash/tabular/classification/model.py | 3 ++- flash/template/classification/model.py | 6 ++---- flash/text/classification/model.py | 6 ++---- flash/text/question_answering/model.py | 6 ++---- flash/text/seq2seq/core/model.py | 6 ++---- flash/text/seq2seq/summarization/model.py | 6 ++---- flash/text/seq2seq/translation/model.py | 4 ++-- 20 files changed, 41 insertions(+), 53 deletions(-) diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index d259c4ae75..f75f7bca7a 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -38,9 +38,9 @@ class SpeechRecognition(Task): Args: backbone: Any speech recognition model from `HuggingFace/transformers `_. + learning_rate: Learning rate to use for training, defaults to ``1e-5``. optimizer: Optimizer to use for training. - lr_scheduler: The scheduler or scheduler class to use. - learning_rate: Learning rate to use for training, defaults to ``1e-3``. + lr_scheduler: The LR scheduler to use during training. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. """ diff --git a/flash/core/model.py b/flash/core/model.py index ed8728a26e..fb7a9bbc52 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -300,13 +300,13 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check Args: model: Model to use for the task. loss_fn: Loss function for training. - optimizer: Optimizer to use for training, defaults to :class:`torch.optim.Adam`. - lr_scheduler: The scheduler or scheduler class to use. + learning_rate: Learning rate to use for training, defaults to ``5e-5``. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature `metric(preds,target)` and return a single scalar tensor. - learning_rate: Learning rate to use for training, defaults to ``5e-5``. deserializer: Either a single :class:`~flash.core.data.process.Deserializer` or a mapping of these to deserialize the input preprocess: :class:`~flash.core.data.process.Preprocess` to use as the default for this task. diff --git a/flash/graph/classification/model.py b/flash/graph/classification/model.py index c1eb736a18..03e8d6dd76 100644 --- a/flash/graph/classification/model.py +++ b/flash/graph/classification/model.py @@ -89,11 +89,10 @@ class GraphClassifier(ClassificationTask): num_features: Number of columns in table (not including target column). num_classes: Number of classes to classify. hidden_channels: Hidden dimension sizes. - loss_fn: Loss function for training, defaults to cross entropy. 
- optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - lr_scheduler: The scheduler or scheduler class to use. - metrics: Metrics to compute for training and evaluation. learning_rate: Learning rate to use for training, defaults to `1e-3` + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. + metrics: Metrics to compute for training and evaluation. model: GraphNN used, defaults to BaseGraphModel. conv_cls: kind of convolution used in model, defaults to GCNConv """ diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index 19ec1f8295..4bd7b46ad2 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -53,8 +53,8 @@ def fn_resnet(pretrained: bool = True): pretrained: A bool or string to specify the pretrained weights of the backbone, defaults to ``True`` which loads the default supervised pretrained weights. loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. - optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. - lr_scheduler: The scheduler or scheduler class to use. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 57440eb8fe..58640f552e 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -37,8 +37,8 @@ class ObjectDetector(AdapterTask): loss: the function(s) to update the model with. Has no effect for torchvision detection models. metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. Changing this argument currently has no effect. - optimizer: The optimizer to use for training. Can either be the actual class or the class name. - lr_scheduler: The scheduler or scheduler class to use. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. Has no effect for custom models. learning_rate: The learning rate to use for training diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index e205eb5054..8423f7fc69 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -51,8 +51,8 @@ class ImageEmbedder(AdapterTask): 'moco_transform', or 'barlow_twins_transform'. backbone: VISSL backbone, defaults to ``resnet``. pretrained: Use a pretrained backbone, defaults to ``False``. - optimizer: Optimizer to use for training and finetuning, defaults to :class:`torch.optim.SGD`. - lr_scheduler: The scheduler or scheduler class to use. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. learning_rate: Learning rate to use for training, defaults to ``1e-3``. backbone_kwargs: arguments to be passed to VISSL backbones, i.e. ``vision_transformer`` and ``resnet``. training_strategy_kwargs: arguments passed to VISSL loss function, projection head and training hooks. 
diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index 042d786dc7..e7b8784d8b 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -55,8 +55,8 @@ class FaceDetector(Task): loss: the function(s) to update the model with. Has no effect for fastface models. metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. Changing this argument currently has no effect. - optimizer: The optimizer to use for training. Can either be the actual class or the class name. - lr_scheduler: The scheduler or scheduler class to use. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training """ diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 4ea6888d14..5718550bc7 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -41,8 +41,8 @@ class InstanceSegmentation(AdapterTask): loss: the function(s) to update the model with. Has no effect for torchvision detection models. metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. Changing this argument currently has no effect. - optimizer: The optimizer to use for training. Can either be the actual class or the class name. - lr_scheduler: The scheduler or scheduler class to use. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. Has no effect for custom models. learning_rate: The learning rate to use for training diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 8e2f3fdd9a..d90998f878 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -37,8 +37,8 @@ class KeypointDetector(AdapterTask): loss: the function(s) to update the model with. Has no effect for torchvision detection models. metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. Changing this argument currently has no effect. - optimizer: The optimizer to use for training. Can either be the actual class or the class name. - lr_scheduler: The scheduler or scheduler class to use. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. Has no effect for custom models. learning_rate: The learning rate to use for training diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index cb9f69a3b4..dad4ab5b74 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -53,7 +53,7 @@ class SemanticSegmentation(ClassificationTask): pretrained: Use a pretrained backbone. loss_fn: Loss function for training. optimizer: Optimizer to use for training. - lr_scheduler: The scheduler or scheduler class to use. + lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. 
In all cases, each metric needs to have the signature diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 789d539ee2..74020f94a0 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -58,8 +58,8 @@ class StyleTransfer(Task): content_weight: The weight associated with the content loss. A lower value will lose content over style. style_layers: Layers from the backbone to derive the style loss from. style_weight: The weight associated with the style loss. A lower value will lose style over content. - optimizer: Optimizer to use for training the model. - lr_scheduler: Scheduler to use for training the model. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. learning_rate: Learning rate to use for training, defaults to ``1e-3``. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. """ diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index 20eed3c81b..617cfffc2b 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -47,8 +47,8 @@ class PointCloudObjectDetector(Task): backbone_kwargs: Any additional kwargs to pass to the backbone constructor. loss_fn: The loss function to use. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. - optimizer: The optimizer or optimizer class to use. - lr_scheduler: The scheduler or scheduler class to use. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index a0fc8fe816..beb16073ba 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -78,10 +78,8 @@ class PointCloudSegmentation(ClassificationTask): backbone_kwargs: Any additional kwargs to pass to the backbone constructor. loss_fn: The loss function to use. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. - optimizer: The optimizer or optimizer class to use. - - lr_scheduler: The scheduler or scheduler class to use. - + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 6578625a99..cf64e4b2da 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -35,7 +35,8 @@ class TabularClassifier(ClassificationTask): num_classes: Number of classes to classify. embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings. loss_fn: Loss function for training, defaults to cross entropy. - optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. 
+ optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index ff25f01467..afc05decb4 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -35,10 +35,8 @@ class TemplateSKLearnClassifier(ClassificationTask): backbone_kwargs: Any additional kwargs to pass to the backbone constructor. loss_fn: The loss function to use. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. - optimizer: The optimizer or optimizer class to use. - - lr_scheduler: The scheduler or scheduler class to use. - + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index b2d8a9cb72..604295a195 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -38,10 +38,8 @@ class TextClassifier(ClassificationTask): Args: num_classes: Number of classes to classify. backbone: A model to use to compute text features can be any BERT model from HuggingFace/transformersimage . - optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - - lr_scheduler: The scheduler or scheduler class to use. - + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 264b938ee6..206ec35dc0 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -67,10 +67,8 @@ class QuestionAnsweringTask(Task): Args: backbone: backbone model to use for the task. loss_fn: Loss function for training. - optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - - lr_scheduler: The scheduler or scheduler class to use. - + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Defauls to calculating the ROUGE metric. Changing this argument currently has no effect. 
learning_rate: Learning rate to use for training, defaults to `3e-4` diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index e124550de0..67e3effe85 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -61,10 +61,8 @@ class Seq2SeqTask(Task): Args: loss_fn: Loss function for training - optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - - lr_scheduler: The scheduler or scheduler class to use. - + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Changing this argument currently has no effect learning_rate: Learning rate to use for training, defaults to `3e-4` val_target_max_length: Maximum length of targets in validation. Defaults to `128` diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index 55e5c58385..65ca10e55e 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -34,10 +34,8 @@ class SummarizationTask(Seq2SeqTask): Args: backbone: backbone model to use for the task. loss_fn: Loss function for training. - optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - - lr_scheduler: The scheduler or scheduler class to use. - + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Defauls to calculating the ROUGE metric. Changing this argument currently has no effect. learning_rate: Learning rate to use for training, defaults to `3e-4` diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index 32845d09da..5ed9c0327e 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -33,8 +33,8 @@ class TranslationTask(Seq2SeqTask): Args: backbone: backbone model to use for the task. loss_fn: Loss function for training. - optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. - lr_scheduler: The scheduler or scheduler class to use. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Defauls to calculating the BLEU metric. Changing this argument currently has no effect. learning_rate: Learning rate to use for training, defaults to `1e-5` From 5dfbeae83c21f333def2fed16ce81c2c7e35c43e Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Thu, 14 Oct 2021 14:56:48 +0530 Subject: [PATCH 18/22] Fix mistake in my CHANGELOG update. --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fcecd397b2..956e7826af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759)) -- Optimizer and LR Scheduler registry are used to get the respective inputs to the Task using a string (or a callable). ([777](https://github.com/PyTorchLightning/lightning-flash/pull/777)) +- Optimizer and LR Scheduler registry are used to get the respective inputs to the Task using a string (or a callable). 
([#777](https://github.com/PyTorchLightning/lightning-flash/pull/777)) ### Fixed From 93dbe672896400fb50ff5f9e791a8960d9af17c9 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Thu, 14 Oct 2021 15:13:20 +0530 Subject: [PATCH 19/22] Removed optimizer old that was commented code. --- flash/image/keypoint_detection/model.py | 1 - flash/image/segmentation/model.py | 1 - flash/image/style_transfer/model.py | 1 - flash/pointcloud/detection/model.py | 1 - flash/pointcloud/segmentation/model.py | 1 - flash/tabular/classification/model.py | 1 - flash/template/classification/model.py | 1 - flash/text/classification/model.py | 1 - flash/text/question_answering/model.py | 1 - flash/text/seq2seq/core/model.py | 1 - flash/text/seq2seq/summarization/model.py | 1 - flash/text/seq2seq/translation/model.py | 1 - flash/video/classification/model.py | 1 - 13 files changed, 13 deletions(-) diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index d90998f878..664348a778 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -81,7 +81,6 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, serializer=serializer or Preds(), ) diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index dad4ab5b74..59017d8ff7 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -105,7 +105,6 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 74020f94a0..353c333389 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -110,7 +110,6 @@ def __init__( model=model, loss_fn=perceptual_loss, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, learning_rate=learning_rate, serializer=serializer, diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index 617cfffc2b..418be30a44 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -85,7 +85,6 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index beb16073ba..fc091e04e7 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -116,7 +116,6 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index cf64e4b2da..0f80ea241a 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -82,7 +82,6 @@ def __init__( model=model, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index afc05decb4..10fbb91f89 100644 --- a/flash/template/classification/model.py +++ 
b/flash/template/classification/model.py @@ -66,7 +66,6 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 604295a195..f96ea79e95 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -82,7 +82,6 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 206ec35dc0..34039d8594 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -114,7 +114,6 @@ def __init__( super().__init__( loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index 67e3effe85..bd3cfc9c63 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -96,7 +96,6 @@ def __init__( super().__init__( loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index 65ca10e55e..1e33a38db8 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -67,7 +67,6 @@ def __init__( backbone=backbone, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index 5ed9c0327e..5e8af986b0 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -66,7 +66,6 @@ def __init__( backbone=backbone, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index e96810d611..86152f914f 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -120,7 +120,6 @@ def __init__( model=None, loss_fn=loss_fn, optimizer=optimizer, - # optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, From 5e76ea35da9217d2b065872a229ae688f0f69c03 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Thu, 14 Oct 2021 20:55:07 +0530 Subject: [PATCH 20/22] Fix dependency version for failing tests on text type data, module - datasets. 
--- requirements/datatype_text.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/datatype_text.txt b/requirements/datatype_text.txt index 9c188058b6..75611db808 100644 --- a/requirements/datatype_text.txt +++ b/requirements/datatype_text.txt @@ -2,4 +2,4 @@ rouge-score>=0.0.4 sentencepiece>=0.1.95 filelock transformers>=4.5 -datasets>=1.8 +datasets>=1.8,<1.13 From c49d70a2f32ab7855f80595388e5f066f3e24350 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Fri, 15 Oct 2021 22:36:23 +0530 Subject: [PATCH 21/22] Changes from review - Fix docs, Add test, Clean up certian parts of the code. --- docs/source/general/optimization.rst | 4 +-- flash/audio/speech_recognition/model.py | 12 +++---- flash/core/model.py | 42 +++++++++++++--------- flash/core/optimizers/schedulers.py | 2 ++ flash/core/utilities/types.py | 18 ++++++++++ flash/graph/classification/model.py | 15 ++++---- flash/image/classification/model.py | 17 ++++----- flash/image/detection/model.py | 12 +++---- flash/image/embedding/model.py | 9 +++-- flash/image/face_detection/model.py | 29 ++++++++------- flash/image/instance_segmentation/model.py | 12 +++---- flash/image/keypoint_detection/model.py | 12 +++---- flash/image/segmentation/model.py | 28 +++++++++------ flash/image/style_transfer/model.py | 12 +++---- flash/pointcloud/detection/model.py | 16 ++++----- flash/pointcloud/segmentation/model.py | 16 ++++----- flash/tabular/classification/model.py | 15 ++++---- flash/template/classification/model.py | 17 ++++----- flash/text/classification/model.py | 17 ++++----- flash/text/question_answering/model.py | 12 +++---- flash/text/seq2seq/core/model.py | 14 ++++---- flash/text/seq2seq/summarization/model.py | 14 ++++---- flash/text/seq2seq/translation/model.py | 15 ++++---- flash/video/classification/model.py | 18 +++++----- tests/core/test_model.py | 4 +++ 25 files changed, 194 insertions(+), 188 deletions(-) create mode 100644 flash/core/utilities/types.py diff --git a/docs/source/general/optimization.rst b/docs/source/general/optimization.rst index fd65244fa4..afe1bee308 100644 --- a/docs/source/general/optimization.rst +++ b/docs/source/general/optimization.rst @@ -12,7 +12,7 @@ With the use of :ref:`registry`, instantiation of an optimzer or a learning rate Setting an optimizer to a task ============================== -Each task has an inbuilt method :func:`~flash.core.model.Task.available_optimizers` which will list all the optimizers +Each task has a built-in method :func:`~flash.core.model.Task.available_optimizers` which will list all the optimizers registered with Flash. >>> from flash.core.classification import ClassificationTask @@ -51,7 +51,7 @@ An alternative to customizing an optimizer using a tuple is to pass it as a call Setting a Learning Rate Scheduler ================================= -Each task has an inbuilt method :func:`~flash.core.model.Task.available_lr_schedulers` which will list all the learning +Each task has a built-in method :func:`~flash.core.model.Task.available_lr_schedulers` which will list all the learning rate schedulers registered with Flash. >>> from flash.core.classification import ClassificationTask diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index f75f7bca7a..18f215b395 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os import warnings -from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union +from typing import Any, Dict import torch import torch.nn as nn @@ -21,11 +21,11 @@ from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding from flash.audio.speech_recognition.data import SpeechRecognitionBackboneState -from flash.core.data.process import Serializer from flash.core.data.states import CollateFn from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _AUDIO_AVAILABLE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE if _AUDIO_AVAILABLE: from transformers import Wav2Vec2Processor @@ -51,12 +51,10 @@ class SpeechRecognition(Task): def __init__( self, backbone: str = "facebook/wav2vec2-base-960h", - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-5, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + serializer: SERIALIZER_TYPE = None, ): os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" # disable HF thousand warnings diff --git a/flash/core/model.py b/flash/core/model.py index 62e245bdf8..5ebc879f98 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -17,7 +17,7 @@ from abc import ABCMeta from copy import deepcopy from importlib import import_module -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union import pytorch_lightning as pl import torch @@ -53,6 +53,17 @@ from flash.core.utilities.apply_func import get_callable_dict from flash.core.utilities.imports import requires from flash.core.utilities.stages import RunningStage +from flash.core.utilities.types import ( + DESERIALIZER_TYPE, + LOSS_FN_TYPE, + LR_SCHEDULER_TYPE, + METRICS_TYPE, + MODEL_TYPE, + OPTIMIZER_TYPE, + POSTPROCESS_TYPE, + PREPROCESS_TYPE, + SERIALIZER_TYPE, +) class ModuleWrapperBase: @@ -322,18 +333,16 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check def __init__( self, - model: Optional[nn.Module] = None, - loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, + model: MODEL_TYPE = None, + loss_fn: LOSS_FN_TYPE = None, learning_rate: float = 5e-5, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, + deserializer: DESERIALIZER_TYPE = None, + preprocess: PREPROCESS_TYPE = None, + postprocess: POSTPROCESS_TYPE = None, + serializer: SERIALIZER_TYPE = None, ): super().__init__() if model is not None: @@ -512,10 +521,10 @@ def configure_optimizers(self) -> Union[Optimizer, 
Tuple[List[Optimizer], List[_ """Implement how optimizer and optionally learning rate schedulers should be configured.""" if isinstance(self.optimizer, str): optimizer_fn = self._get_optimizer_class_from_registry(self.optimizer.lower()) - _optimizers_kwargs: Dict[str, Any] = {} + optimizers_kwargs: Dict[str, Any] = {"lr": self.learning_rate} elif isinstance(self.optimizer, Callable): optimizer_fn = self.optimizer - _optimizers_kwargs: Dict[str, Any] = {} + optimizers_kwargs: Dict[str, Any] = {"lr": self.learning_rate} elif isinstance(self.optimizer, Tuple): if len(self.optimizer) != 2: raise MisconfigurationException( @@ -536,7 +545,8 @@ def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_ ) optimizer_fn: Callable = self._get_optimizer_class_from_registry(self.optimizer[0]) - _optimizers_kwargs: Dict[str, Any] = self.optimizer[1] + optimizers_kwargs: Dict[str, Any] = self.optimizer[1] + optimizers_kwargs["lr"] = self.learning_rate else: raise TypeError( f"""Optimizer should be of type string or callable or tuple(string, dictionary) @@ -544,7 +554,7 @@ def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_ ) model_parameters = filter(lambda p: p.requires_grad, self.parameters()) - optimizer: Optimizer = optimizer_fn(model_parameters, lr=self.learning_rate, **_optimizers_kwargs) + optimizer: Optimizer = optimizer_fn(model_parameters, **optimizers_kwargs) if self.lr_scheduler is not None: return [optimizer], [self._instantiate_lr_scheduler(optimizer)] return optimizer diff --git a/flash/core/optimizers/schedulers.py b/flash/core/optimizers/schedulers.py index b385dafff8..e264795f4b 100644 --- a/flash/core/optimizers/schedulers.py +++ b/flash/core/optimizers/schedulers.py @@ -25,6 +25,8 @@ if inspect.isclass(sched) and sched != _LRScheduler and issubclass(sched, _LRScheduler): schedulers.append(sched) + +# Adding `ReduceLROnPlateau` separately as it is subclassed from `object` and not `_LRScheduler`. schedulers.append(ReduceLROnPlateau) diff --git a/flash/core/utilities/types.py b/flash/core/utilities/types.py new file mode 100644 index 0000000000..8138db88b5 --- /dev/null +++ b/flash/core/utilities/types.py @@ -0,0 +1,18 @@ +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union + +from torch import nn +from torchmetrics import Metric + +from flash.core.data.process import Deserializer, Postprocess, Preprocess, Serializer + +MODEL_TYPE = Optional[nn.Module] +LOSS_FN_TYPE = Optional[Union[Callable, Mapping, Sequence]] +OPTIMIZER_TYPE = Union[str, Callable, Tuple[str, Dict[str, Any]]] +LR_SCHEDULER_TYPE = Optional[ + Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] +] +METRICS_TYPE = Union[Metric, Mapping, Sequence, None] +DESERIALIZER_TYPE = Optional[Union[Deserializer, Mapping[str, Deserializer]]] +PREPROCESS_TYPE = Optional[Preprocess] +POSTPROCESS_TYPE = Optional[Postprocess] +SERIALIZER_TYPE = Optional[Union[Serializer, Mapping[str, Serializer]]] diff --git a/flash/graph/classification/model.py b/flash/graph/classification/model.py index 03e8d6dd76..265703ea22 100644 --- a/flash/graph/classification/model.py +++ b/flash/graph/classification/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, List, Type, Union import torch from torch import nn @@ -20,6 +20,7 @@ from flash.core.classification import ClassificationTask from flash.core.utilities.imports import _GRAPH_AVAILABLE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE if _GRAPH_AVAILABLE: from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool, MessagePassing @@ -104,14 +105,12 @@ def __init__( num_features: int, num_classes: int, hidden_channels: Union[List[int], int] = 512, - loss_fn: Callable = F.cross_entropy, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Callable, Mapping, Sequence, None] = None, - learning_rate: float = 1e-3, model: torch.nn.Module = None, + loss_fn: LOSS_FN_TYPE = F.cross_entropy, + learning_rate: float = 1e-3, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, conv_cls: Type[MessagePassing] = GCNConv, **conv_kwargs ): diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index 4bd7b46ad2..4d23f9aa3a 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import FunctionType -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn -from torchmetrics import Metric from flash.core.classification import ClassificationAdapterTask, Labels -from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE from flash.image.classification.adapters import TRAINING_STRATEGIES from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES @@ -80,15 +79,13 @@ def __init__( backbone_kwargs: Optional[Dict] = None, head: Optional[Union[FunctionType, nn.Module]] = None, pretrained: Union[bool, str] = True, - loss_fn: Optional[Callable] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, + loss_fn: LOSS_FN_TYPE = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 1e-3, multi_label: bool = False, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + serializer: SERIALIZER_TYPE = None, training_strategy: Optional[str] = "default", training_strategy_kwargs: Optional[Dict[str, Any]] = None, ): diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 58640f552e..6a7cd8aff3 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, List, Optional from flash.core.adapter import AdapterTask -from flash.core.data.process import Serializer from flash.core.data.serialization import Preds from flash.core.registry import FlashRegistry +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE from flash.image.detection.backbones import OBJECT_DETECTION_HEADS @@ -57,12 +57,10 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "retinanet", pretrained: bool = True, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-3, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + serializer: SERIALIZER_TYPE = None, **kwargs: Any, ): self.save_hyperparameters() diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 8423f7fc69..ebc6d73d32 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional from flash.core.adapter import AdapterTask from flash.core.data.data_source import DefaultDataKeys @@ -20,6 +20,7 @@ from flash.core.data.transforms import ApplyToKeys from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _VISSL_AVAILABLE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE if _VISSL_AVAILABLE: import classy_vision @@ -72,10 +73,8 @@ def __init__( pretraining_transform: str, backbone: str = "resnet", pretrained: bool = False, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-3, backbone_kwargs: Optional[Dict[str, Any]] = None, training_strategy_kwargs: Optional[Dict[str, Any]] = None, diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index e7b8784d8b..042c417848 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -11,17 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List import pytorch_lightning as pl import torch -from torch import nn from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Preprocess, Serializer +from flash.core.data.process import Serializer from flash.core.finetuning import FlashBaseFinetuning from flash.core.model import Task from flash.core.utilities.imports import _FASTFACE_AVAILABLE +from flash.core.utilities.types import ( + LOSS_FN_TYPE, + LR_SCHEDULER_TYPE, + METRICS_TYPE, + OPTIMIZER_TYPE, + PREPROCESS_TYPE, + SERIALIZER_TYPE, +) from flash.image.face_detection.backbones import FACE_DETECTION_BACKBONES from flash.image.face_detection.data import FaceDetectionPreprocess @@ -66,15 +73,13 @@ def __init__( self, model: str = "lffd_slim", pretrained: bool = True, - loss=None, - metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, + loss_fn: LOSS_FN_TYPE = None, + metrics: METRICS_TYPE = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-4, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, - preprocess: Optional[Preprocess] = None, + serializer: SERIALIZER_TYPE = None, + preprocess: PREPROCESS_TYPE = None, **kwargs: Any, ): self.save_hyperparameters() @@ -86,7 +91,7 @@ def __init__( super().__init__( model=model, - loss_fn=loss, + loss_fn=loss_fn, metrics=metrics or {"AP": ff.metric.AveragePrecision()}, # TODO: replace with torch metrics MAP learning_rate=learning_rate, optimizer=optimizer, diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 5718550bc7..ae68668768 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, List, Optional from pytorch_lightning.utilities import rank_zero_info from flash.core.adapter import AdapterTask from flash.core.data.data_pipeline import DataPipeline -from flash.core.data.process import Serializer from flash.core.data.serialization import Preds from flash.core.registry import FlashRegistry +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS from flash.image.instance_segmentation.data import InstanceSegmentationPostProcess, InstanceSegmentationPreprocess @@ -59,12 +59,10 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "mask_rcnn", pretrained: bool = True, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + serializer: SERIALIZER_TYPE = None, **kwargs: Any, ): self.save_hyperparameters() diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 664348a778..306b334d12 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, List, Optional from flash.core.adapter import AdapterTask -from flash.core.data.process import Serializer from flash.core.data.serialization import Preds from flash.core.registry import FlashRegistry +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS @@ -56,12 +56,10 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "keypoint_rcnn", pretrained: bool = True, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + serializer: SERIALIZER_TYPE = None, **kwargs: Any, ): self.save_hyperparameters() diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 59017d8ff7..b0b293ad6b 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -11,19 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Union import torch from torch import nn from torch.nn import functional as F -from torchmetrics import IoU, Metric +from torchmetrics import IoU from flash.core.classification import ClassificationTask from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Postprocess, Serializer +from flash.core.data.process import Postprocess from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _KORNIA_AVAILABLE from flash.core.utilities.isinstance import _isinstance +from flash.core.utilities.types import ( + LOSS_FN_TYPE, + LR_SCHEDULER_TYPE, + METRICS_TYPE, + OPTIMIZER_TYPE, + POSTPROCESS_TYPE, + SERIALIZER_TYPE, +) from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS from flash.image.segmentation.serialization import SegmentationLabels @@ -80,16 +88,14 @@ def __init__( head: str = "fpn", head_kwargs: Optional[Dict] = None, pretrained: Union[bool, str] = True, - loss_fn: Optional[Callable] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, + loss_fn: LOSS_FN_TYPE = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 1e-3, multi_label: bool = False, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, - postprocess: Optional[Postprocess] = None, + serializer: SERIALIZER_TYPE = None, + postprocess: POSTPROCESS_TYPE = None, ) -> None: if metrics is None: metrics = IoU(num_classes=num_classes) diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 353c333389..a03575a64a 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, cast, Dict, List, Mapping, NoReturn, Optional, Sequence, Tuple, Union +from typing import Any, cast, List, NoReturn, Optional, Sequence, Tuple, Union import torch from torch import nn from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Serializer from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _IMAGE_AVAILABLE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE from flash.image.style_transfer import STYLE_TRANSFER_BACKBONES if _IMAGE_AVAILABLE: @@ -77,12 +77,10 @@ def __init__( content_weight: float = 1e5, style_layers: Union[Sequence[str], str] = ["relu1_2", "relu2_2", "relu3_3", "relu4_3"], style_weight: float = 1e10, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-3, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + serializer: SERIALIZER_TYPE = None, ): self.save_hyperparameters(ignore="style_image") diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index 418be30a44..b35604cae3 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -import torchmetrics from torch import nn from torch.utils.data import DataLoader, Sampler @@ -27,6 +26,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.apply_func import get_callable_dict from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE from flash.pointcloud.detection.backbones import POINTCLOUD_OBJECT_DETECTION_BACKBONES __FILE_EXAMPLE__ = "pointcloud_detection" @@ -68,14 +68,12 @@ def __init__( backbone: Union[str, Tuple[nn.Module, int]] = "pointpillars_kitti", backbone_kwargs: Optional[Dict] = None, head: Optional[nn.Module] = None, - loss_fn: Optional[Callable] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + loss_fn: LOSS_FN_TYPE = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudObjectDetectorSerializer(), + serializer: SERIALIZER_TYPE = PointCloudObjectDetectorSerializer(), lambda_loss_cls: float = 1.0, lambda_loss_bbox: float = 1.0, lambda_loss_dir: float = 1.0, diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index fc091e04e7..e8578b586c 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 
express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -import torchmetrics from pytorch_lightning import Callback, LightningModule from torch import nn from torch.nn import functional as F @@ -30,6 +29,7 @@ from flash.core.finetuning import BaseFinetuning from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES if _POINTCLOUD_AVAILABLE: @@ -97,15 +97,13 @@ def __init__( backbone: Union[str, Tuple[nn.Module, int]] = "RandLANet", backbone_kwargs: Optional[Dict] = None, head: Optional[nn.Module] = None, - loss_fn: Optional[Callable] = torch.nn.functional.cross_entropy, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + loss_fn: LOSS_FN_TYPE = torch.nn.functional.cross_entropy, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudSegmentationSerializer(), + serializer: SERIALIZER_TYPE = PointCloudSegmentationSerializer(), ): import flash diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 0f80ea241a..2ed21da6e0 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple import torch from torch.nn import functional as F -from torchmetrics import Metric from flash.core.classification import ClassificationTask, Probabilities from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Serializer from flash.core.utilities.imports import _TABULAR_AVAILABLE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE if _TABULAR_AVAILABLE: from pytorch_tabnet.tab_network import TabNet @@ -56,14 +55,12 @@ def __init__( num_classes: int, embedding_sizes: List[Tuple[int, int]] = None, loss_fn: Callable = F.cross_entropy, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + serializer: SERIALIZER_TYPE = None, **tabnet_kwargs, ): self.save_hyperparameters() diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index 10fbb91f89..3350972567 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch -import torchmetrics from torch import nn from flash.core.classification import ClassificationTask, Labels from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE from flash.template.classification.backbones import TEMPLATE_BACKBONES @@ -52,15 +51,13 @@ def __init__( num_classes: int, backbone: Union[str, Tuple[nn.Module, int]] = "mlp-128", backbone_kwargs: Optional[Dict] = None, - loss_fn: Optional[Callable] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + loss_fn: LOSS_FN_TYPE = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + serializer: SERIALIZER_TYPE = None, ): super().__init__( model=None, diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index f96ea79e95..c718d48c2b 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -13,16 +13,15 @@ # limitations under the License. 
import os import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List import torch from pytorch_lightning import Callback -from torchmetrics import Metric from flash.core.classification import ClassificationTask, Labels -from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE from flash.text.classification.backbones import TEXT_CLASSIFIER_BACKBONES from flash.text.ort_callback import ORTCallback @@ -58,15 +57,13 @@ def __init__( self, num_classes: int, backbone: str = "prajjwal1/bert-medium", - loss_fn: Optional[Callable] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, + loss_fn: LOSS_FN_TYPE = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + serializer: SERIALIZER_TYPE = None, enable_ort: bool = False, ): self.save_hyperparameters() diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 34039d8594..7087239c18 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -19,14 +19,13 @@ import collections import os import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Mapping, Optional, Sequence, Union import numpy as np import torch from pytorch_lightning import Callback from pytorch_lightning.utilities import rank_zero_info from torch import Tensor -from torchmetrics import Metric from flash.core.data.data_source import DefaultDataKeys from flash.core.finetuning import FlashBaseFinetuning @@ -34,6 +33,7 @@ from flash.core.registry import ExternalRegistry, FlashRegistry from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.core.utilities.providers import _HUGGINGFACE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.ort_callback import ORTCallback from flash.text.question_answering.finetuning import QuestionAnsweringFreezeEmbeddings from flash.text.seq2seq.core.metrics import RougeMetric @@ -92,11 +92,9 @@ def __init__( self, backbone: str = "distilbert-base-uncased", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 5e-5, enable_ort: bool = False, n_best_size: int = 20, diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index bd3cfc9c63..060ef5fab1 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -13,19 +13,19 @@ # limitations under the License. 
import os import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional import torch from pytorch_lightning import Callback from pytorch_lightning.utilities import rank_zero_info from torch import Tensor -from torchmetrics import Metric from flash.core.finetuning import FlashBaseFinetuning from flash.core.model import Task from flash.core.registry import ExternalRegistry, FlashRegistry from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.core.utilities.providers import _HUGGINGFACE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.ort_callback import ORTCallback from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings @@ -77,12 +77,10 @@ class Seq2SeqTask(Task): def __init__( self, backbone: str = "t5-small", - loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, + loss_fn: LOSS_FN_TYPE = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 5e-5, val_target_max_length: Optional[int] = None, num_beams: Optional[int] = None, diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index 1e33a38db8..99bf064ad7 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional import torch -from torchmetrics import Metric +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.seq2seq.core.metrics import RougeMetric from flash.text.seq2seq.core.model import Seq2SeqTask @@ -49,12 +49,10 @@ class SummarizationTask(Seq2SeqTask): def __init__( self, backbone: str = "sshleifer/distilbart-xsum-1-1", - loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, + loss_fn: LOSS_FN_TYPE = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 1e-5, val_target_max_length: Optional[int] = None, num_beams: Optional[int] = 4, diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index 5e8af986b0..f47d4f6b08 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union - -from torchmetrics import Metric +from typing import Any, Dict, List, Optional +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.seq2seq.core.metrics import BLEUScore from flash.text.seq2seq.core.model import Seq2SeqTask @@ -48,12 +47,10 @@ class TranslationTask(Seq2SeqTask): def __init__( self, backbone: str = "t5-small", - loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, + loss_fn: LOSS_FN_TYPE = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, learning_rate: float = 1e-5, val_target_max_length: Optional[int] = 128, num_beams: Optional[int] = 4, diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 86152f914f..bddf95f75c 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import FunctionType -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Union import torch from pytorch_lightning import LightningModule @@ -23,15 +23,15 @@ from torch.nn import functional as F from torch.optim import Optimizer from torch.utils.data import DistributedSampler -from torchmetrics import Accuracy, Metric +from torchmetrics import Accuracy import flash from flash.core.classification import ClassificationTask, Labels from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE from flash.core.utilities.providers import _PYTORCHVIDEO +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE _VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones") @@ -106,15 +106,13 @@ def __init__( backbone: Union[str, nn.Module] = "x3d_xs", backbone_kwargs: Optional[Dict] = None, pretrained: bool = True, - loss_fn: Callable = F.cross_entropy, - optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "SGD", - lr_scheduler: Optional[ - Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]]] - ] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = Accuracy(), + loss_fn: LOSS_FN_TYPE = F.cross_entropy, + optimizer: OPTIMIZER_TYPE = "SGD", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = Accuracy(), learning_rate: float = 1e-3, head: Optional[Union[FunctionType, nn.Module]] = None, - serializer: Optional[Serializer] = None, + serializer: SERIALIZER_TYPE = None, ): super().__init__( model=None, diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d1f2b9c478..0e68344bb5 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -401,6 +401,10 @@ def test_errors_and_exceptions_optimizers_and_schedulers(): ) task.configure_optimizers() + with pytest.raises(TypeError): + task = ClassificationTask(model, optimizer=("Adam", 
["non", "dict", "type"]), lr_scheduler=None) + task.configure_optimizers() + with pytest.raises(KeyError): task = ClassificationTask(model, optimizer="Adam", lr_scheduler="not_a_valid_key") task.configure_optimizers() From 35f383408a39b4f5c3e7f96720581c1086c3c0ea Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Mon, 18 Oct 2021 20:35:50 +0530 Subject: [PATCH 22/22] Remove debug print statements. --- flash/core/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flash/core/model.py b/flash/core/model.py index 5ebc879f98..051ab83b54 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -914,7 +914,6 @@ def _get_lr_scheduler_class_from_registry(self, lr_scheduler_key: str) -> Dict[s return deepcopy(lr_scheduler_fn) def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: - print(type(self.lr_scheduler)) if isinstance(self.lr_scheduler, str): lr_scheduler_data: Dict[str, Any] = self._get_lr_scheduler_class_from_registry(self.lr_scheduler) lr_scheduler_fn = lr_scheduler_data.pop("fn")