This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

PoC: Revamp optimizer and scheduler experience using registries #777

Merged
Changes from 6 commits
Commits
34 commits
2f46b93
Change optimizer Callables alone and scheduler to support Callables a…
karthikrangasai Sep 15, 2021
caefe68
Add Optimizer Registry and Update __init__ for all tasks.
karthikrangasai Sep 15, 2021
93bc1b5
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 20, 2021
7ea53a2
Revamp scheduler parameter to use str, Callable, str with params.
karthikrangasai Sep 22, 2021
e95a209
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 26, 2021
4cf6cdd
Updated _instantiate_scheduler method to handle providers. Added supp…
karthikrangasai Sep 26, 2021
440aef2
wip
tchaton Sep 27, 2021
094b690
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 29, 2021
06e7722
Updated scheduler parameter to take input as type Tuple[str, Dict[str…
karthikrangasai Sep 29, 2021
8ab54bd
Update naming of scheduler parameter to lr_scheduler.
karthikrangasai Sep 29, 2021
617e53a
Update optimizer and lr_scheduler parameter across all tasks.
karthikrangasai Sep 29, 2021
dd5615e
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 29, 2021
7a3029b
Updated optimizer registration code to compare with optimizer types a…
karthikrangasai Sep 29, 2021
d36c451
Added tests for Errors and Exceptions.
karthikrangasai Sep 29, 2021
061454b
Update README with examples on using the API.
karthikrangasai Sep 30, 2021
c611aa8
Update skipif condition only to check for transformers library instea…
karthikrangasai Sep 30, 2021
64cedf3
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 1, 2021
e158802
Update newly added Face Detection Task.
karthikrangasai Oct 1, 2021
c8cb598
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 4, 2021
fcb3916
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 7, 2021
eda81ae
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 13, 2021
20eacaf
Changes from code review, Add new input method to lr_scheduler parame…
karthikrangasai Oct 13, 2021
87cf563
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 13, 2021
ddb5d1f
Fix pre-commit ci review.
karthikrangasai Oct 13, 2021
eb3aaec
Add documentation for using the modified API and update CHANGELOG.
karthikrangasai Oct 14, 2021
50c936a
Update docstrings for all tasks.
karthikrangasai Oct 14, 2021
5dfbeae
Fix mistake in my CHANGELOG update.
karthikrangasai Oct 14, 2021
93dbe67
Removed old commented-out optimizer code.
karthikrangasai Oct 14, 2021
42e3bf4
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 14, 2021
5e76ea3
Fix dependency version for failing tests on text type data, module - …
karthikrangasai Oct 14, 2021
ec348bf
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 15, 2021
c49d70a
Changes from review - Fix docs, Add test, Clean up certain parts of t…
karthikrangasai Oct 15, 2021
66c30bc
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 18, 2021
35f3834
Remove debug print statements.
karthikrangasai Oct 18, 2021
8 changes: 4 additions & 4 deletions flash/audio/speech_recognition/model.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
import warnings
from typing import Any, Dict, Mapping, Optional, Type, Union
from typing import Any, Callable, Dict, Mapping, Optional, Type, Union

import torch
import torch.nn as nn
@@ -54,8 +54,8 @@ class SpeechRecognition(Task):
def __init__(
self,
backbone: str = "facebook/wav2vec2-base-960h",
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam",
# optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
learning_rate: float = 1e-5,
@@ -71,7 +71,7 @@ def __init__(
super().__init__(
model=model,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
# optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
learning_rate=learning_rate,
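For reference, a minimal sketch of constructing this task against the registry-based API shown in the diff (the backbone is the default above; other values are illustrative, not taken from this PR):

from flash.audio import SpeechRecognition

# "Adam" is resolved through the new optimizer registry instead of passing torch.optim.Adam directly.
model = SpeechRecognition(
    backbone="facebook/wav2vec2-base-960h",
    optimizer="Adam",
    learning_rate=1e-5,
)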
95 changes: 65 additions & 30 deletions flash/core/model.py
@@ -46,9 +46,10 @@
SerializerMapping,
)
from flash.core.data.properties import ProcessState
from flash.core.optimizers import _OPTIMIZERS_REGISTRY, _SCHEDULERS_REGISTRY
from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.serve import Composition
from flash.core.utilities import providers
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import requires

@@ -303,6 +304,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check
postprocess: :class:`~flash.core.data.process.Postprocess` to use as the default for this task.
"""

optimizers: FlashRegistry = _OPTIMIZERS_REGISTRY
schedulers: FlashRegistry = _SCHEDULERS_REGISTRY

required_extras: Optional[Union[str, List[str]]] = None
@@ -311,12 +313,12 @@ def __init__(
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam",
# optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[str, Callable, Tuple[str, Tuple[Any, ...]]]] = None,
# scheduler_kwargs: Optional[Dict[str, Any]] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None,
preprocess: Optional[Preprocess] = None,
postprocess: Optional[Postprocess] = None,
@@ -328,8 +330,10 @@ def __init__(
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
self.optimizer = optimizer
self.scheduler = scheduler
self.optimizer_kwargs = optimizer_kwargs or {}
self.scheduler_kwargs = scheduler_kwargs or {}
if isinstance(self.scheduler, Tuple):
assert isinstance(self.scheduler[0], str)
# self.optimizer_kwargs: Dict[str, Any] = optimizer_kwargs or {}
# self.scheduler_kwargs: Dict[str, Any] = scheduler_kwargs or {}

self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
@@ -474,11 +478,15 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
return self(batch)

def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
optimizer = self.optimizer
if not isinstance(self.optimizer, Optimizer):
self.optimizer_kwargs["lr"] = self.learning_rate
optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs)
if self.scheduler:
if isinstance(self.optimizer, str):
self.optimizer = self.optimizers.get(self.optimizer)

model_parameters = filter(lambda p: p.requires_grad, self.parameters())
optimizer: Optimizer = self.optimizer(model_parameters, lr=self.learning_rate)
# if not isinstance(self.optimizer, Optimizer):
# self.optimizer_kwargs["lr"] = self.learning_rate
# optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs)
if self.scheduler is not None:
return [optimizer], [self._instantiate_scheduler(optimizer)]
return optimizer

@@ -817,23 +825,50 @@ def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float]
return round(num_warmup_steps)

def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
scheduler = self.scheduler
if isinstance(scheduler, _LRScheduler):
return scheduler
if isinstance(scheduler, str):
scheduler_fn = self.schedulers.get(self.scheduler)
num_training_steps: int = self.get_num_training_steps()
num_warmup_steps: int = self._compute_warmup(
num_training_steps=num_training_steps,
num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
)
return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
if issubclass(scheduler, _LRScheduler):
return scheduler(optimizer, **self.scheduler_kwargs)
raise MisconfigurationException(
"scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` "
f"or a built-in scheduler in {self.available_schedulers()}"
)
if isinstance(self.scheduler, Callable):
return self.scheduler(optimizer)

if isinstance(self.scheduler, str):
return self.schedulers.get(self.scheduler)(optimizer)

if not isinstance(self.scheduler, Tuple):
raise TypeError("") # Add message

# By default it is Tuple[str, Tuple[Any, ...]] type now.
scheduler_key = self.scheduler[0]
scheduler_args = self.scheduler[1]

scheduler = self.schedulers.get(scheduler_key, with_metadata=True)
scheduler_fn = scheduler["fn"]
scheduler_metadata = scheduler["metadata"]

if "providers" in scheduler_metadata.keys():
if scheduler_metadata["providers"] == providers._HUGGINGFACE:
num_training_steps: int = self.get_num_training_steps()
num_warmup_steps: int = self._compute_warmup(
num_training_steps=num_training_steps,
num_warmup_steps=scheduler_args[0], # num_warmup_steps is the first arg in all schedulers
)
if scheduler["name"] == "constant_schedule_with_warmup":
scheduler_args = (num_warmup_steps, *scheduler_args[1:])
else:
scheduler_args = (num_warmup_steps, num_training_steps, *scheduler_args[1:])

# if isinstance(scheduler, str):
# scheduler_fn = self.schedulers.get(self.scheduler)
# num_training_steps: int = self.get_num_training_steps()
# num_warmup_steps: int = self._compute_warmup(
# num_training_steps=num_training_steps,
# num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
# )
# return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
# # if issubclass(scheduler, _LRScheduler):
# # return scheduler(optimizer, **self.scheduler_kwargs)
# raise MisconfigurationException(
# "scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` "
# f"or a built-in scheduler in {self.available_schedulers()}"
# )
return scheduler_fn(optimizer, *scheduler_args)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
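To make the branching in _instantiate_scheduler concrete, here is a minimal sketch of the three scheduler forms it now handles. SomeFlashTask is a hypothetical Task subclass and the argument values are illustrative, not part of this PR:

from functools import partial

from torch.optim.lr_scheduler import StepLR

# (a) Callable: called with only the optimizer, so bind any extra arguments up front.
task = SomeFlashTask(optimizer="Adam", scheduler=partial(StepLR, step_size=10))

# (b) Registry name as a string: resolved via Task.schedulers and called with only the
#     optimizer, which suits schedulers whose remaining arguments all have defaults.

# (c) Tuple of (name, args): the args are unpacked after the optimizer; for the HuggingFace
#     entries the first element is num_warmup_steps (an int, or a float fraction of the
#     total training steps resolved by _compute_warmup).
task = SomeFlashTask(optimizer="Adam", scheduler=("StepLR", (10,)))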
2 changes: 2 additions & 0 deletions flash/core/optimizers/__init__.py
@@ -1,3 +1,5 @@
from flash.core.optimizers.lamb import LAMB # noqa: F401
from flash.core.optimizers.lars import LARS # noqa: F401
from flash.core.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR # noqa: F401
from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY # noqa: F401
from flash.core.optimizers.schedulers import _SCHEDULERS_REGISTRY # noqa: F401
12 changes: 12 additions & 0 deletions flash/core/optimizers/optimizers.py
@@ -0,0 +1,12 @@
from typing import Callable, List

from torch import optim

from flash.core.registry import FlashRegistry

_OPTIMIZERS_REGISTRY = FlashRegistry("optimizer")

_optimizers: List[Callable] = [getattr(optim, n) for n in dir(optim) if ("_" not in n)]

for fn in _optimizers:
_OPTIMIZERS_REGISTRY(fn, name=fn.__name__)
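A quick sketch of what the loop above registers, assuming FlashRegistry's get and available_keys helpers behave as in the rest of the codebase:

from flash.core.optimizers import _OPTIMIZERS_REGISTRY

adam_cls = _OPTIMIZERS_REGISTRY.get("Adam")   # torch.optim.Adam, registered under its class name
print(_OPTIMIZERS_REGISTRY.available_keys())  # e.g. ['ASGD', 'Adadelta', 'Adagrad', 'Adam', ...]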
33 changes: 33 additions & 0 deletions flash/core/optimizers/schedulers.py
@@ -0,0 +1,33 @@
import inspect
from typing import Callable, List

from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE

_SCHEDULERS_REGISTRY = FlashRegistry("scheduler")

schedulers: List[_LRScheduler] = []
for n in dir(lr_scheduler):
sched = getattr(lr_scheduler, n)

if inspect.isclass(sched) and sched != _LRScheduler and issubclass(sched, _LRScheduler):
schedulers.append(sched)


for scheduler in schedulers:
_SCHEDULERS_REGISTRY(scheduler, name=scheduler.__name__)

if _TRANSFORMERS_AVAILABLE:
from transformers import optimization

functions: List[Callable] = []
for n in dir(optimization):
if "get_" in n and n != "get_scheduler":
functions.append(getattr(optimization, n))

for fn in functions:
_SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:], providers=_HUGGINGFACE)
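A sketch of how the HuggingFace entries registered above would be consumed through the tuple form. The task, backbone, and warmup value are illustrative; the scheduler name follows the "get_" prefix stripping in the loop above:

from flash.text import TextClassifier

# Assumes transformers is installed so the HuggingFace schedulers are registered.
model = TextClassifier(
    num_classes=2,
    backbone="prajjwal1/bert-tiny",
    optimizer="Adam",
    # Registered from transformers' get_linear_schedule_with_warmup with the "get_" prefix stripped.
    # The first tuple element is num_warmup_steps, passed through _compute_warmup
    # (here a fraction of the total training steps).
    scheduler=("linear_schedule_with_warmup", (0.1,)),
)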
15 changes: 0 additions & 15 deletions flash/core/schedulers.py

This file was deleted.

6 changes: 3 additions & 3 deletions flash/graph/classification/model.py
@@ -109,8 +109,8 @@ def __init__(
num_classes: int,
hidden_channels: Union[List[int], int] = 512,
loss_fn: Callable = F.cross_entropy,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam",
# optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
metrics: Union[Callable, Mapping, Sequence, None] = None,
@@ -132,7 +132,7 @@
model=model,
loss_fn=loss_fn,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
# optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
metrics=metrics,
6 changes: 3 additions & 3 deletions flash/image/classification/model.py
@@ -81,8 +81,8 @@ def __init__(
head: Optional[Union[FunctionType, nn.Module]] = None,
pretrained: Union[bool, str] = True,
loss_fn: Optional[Callable] = None,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam",
# optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
@@ -138,7 +138,7 @@
metrics=metrics,
learning_rate=learning_rate,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
# optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
multi_label=multi_label,
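As a usage note, the optimizer can also be passed as a callable that builds the optimizer from the model parameters, since configure_optimizers calls it as optimizer(parameters, lr=learning_rate). A sketch only; the backbone and hyper-parameter values are illustrative:

from functools import partial

import torch

from flash.image import ImageClassifier

model = ImageClassifier(
    num_classes=10,
    backbone="resnet18",
    # The partial receives the trainable parameters plus lr=learning_rate from configure_optimizers.
    optimizer=partial(torch.optim.AdamW, weight_decay=1e-4),
    learning_rate=1e-3,
)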
9 changes: 4 additions & 5 deletions flash/image/detection/model.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Mapping, Optional, Type, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from flash.core.adapter import AdapterTask
@@ -61,8 +60,8 @@ def __init__(
backbone: Optional[str] = "resnet18_fpn",
head: Optional[str] = "retinanet",
pretrained: bool = True,
optimizer: Type[Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam",
# optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
learning_rate: float = 5e-3,
@@ -85,7 +84,7 @@
adapter,
learning_rate=learning_rate,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
# optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
serializer=serializer or Preds(),
8 changes: 4 additions & 4 deletions flash/image/embedding/model.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
@@ -77,8 +77,8 @@ def __init__(
pretraining_transform: str,
backbone: str = "resnet",
pretrained: bool = False,
optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "SGD",
# optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
learning_rate: float = 1e-3,
@@ -110,7 +110,7 @@ def __init__(
super().__init__(
adapter=adapter,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
# optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
learning_rate=learning_rate,
9 changes: 4 additions & 5 deletions flash/image/instance_segmentation/model.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Mapping, Optional, Type, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from flash.core.adapter import AdapterTask
@@ -61,8 +60,8 @@ def __init__(
backbone: Optional[str] = "resnet18_fpn",
head: Optional[str] = "mask_rcnn",
pretrained: bool = True,
optimizer: Type[Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
optimizer: Union[Callable[..., torch.optim.Optimizer], str] = "Adam",
# optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
learning_rate: float = 5e-4,
@@ -85,7 +84,7 @@
adapter,
learning_rate=learning_rate,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
# optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
serializer=serializer or Preds(),