This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Add optimizer_kwargs, scheduler, and scheduler_kwargs parameters to Task __init__ methods. (#730)
karthikrangasai authored Sep 6, 2021
1 parent 7cf77c7 commit cf86275
Showing 15 changed files with 153 additions and 6 deletions.
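Every task below gains the same three constructor arguments (`optimizer_kwargs`, `scheduler`, and `scheduler_kwargs`) and forwards them unchanged to the base `Task.__init__`. The base class consumes them when it builds the optimizer and LR scheduler; that change lives in flash/core/model.py, presumably among the changed files not rendered below. As a rough sketch of the resolution that the annotations imply (an assumption, not the actual flash implementation):

from typing import Any, Dict, Optional, Type, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler


def build_optimizer_and_scheduler(
    parameters,
    optimizer: Type[torch.optim.Optimizer],
    optimizer_kwargs: Optional[Dict[str, Any]],
    learning_rate: float,
    scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]],
    scheduler_kwargs: Optional[Dict[str, Any]],
):
    # Instantiate the optimizer class with the learning rate and extra kwargs.
    opt = optimizer(parameters, lr=learning_rate, **(optimizer_kwargs or {}))
    if scheduler is None:
        return opt, None
    if isinstance(scheduler, _LRScheduler):
        return opt, scheduler  # already built; used as-is
    if isinstance(scheduler, str):
        # A registered scheduler name; the lookup flash performs is not
        # visible in this diff, so it is omitted from the sketch.
        raise NotImplementedError("name-based lookup not shown in this diff")
    # Otherwise a scheduler class: instantiate it against the optimizer.
    return opt, scheduler(opt, **(scheduler_kwargs or {}))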
7 changes: 7 additions & 0 deletions flash/audio/speech_recognition/model.py

@@ -17,6 +17,7 @@

 import torch
 import torch.nn as nn
+from torch.optim.lr_scheduler import _LRScheduler

 from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES
 from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding
@@ -42,6 +43,9 @@ def __init__(
         backbone: str = "facebook/wav2vec2-base-960h",
         loss_fn: Optional[Callable] = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         learning_rate: float = 1e-5,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
     ):
@@ -58,6 +62,9 @@
             model=model,
             loss_fn=loss_fn,
             optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
             learning_rate=learning_rate,
             serializer=serializer,
         )
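A usage sketch for the signature above. The task name and backbone come from the diff; the optimizer and scheduler choices and their kwargs are illustrative only:

import torch
from flash.audio import SpeechRecognition

# Illustrative values; any torch optimizer class and _LRScheduler subclass
# should fit the annotations shown in the diff.
model = SpeechRecognition(
    backbone="facebook/wav2vec2-base-960h",
    optimizer=torch.optim.AdamW,
    optimizer_kwargs={"weight_decay": 0.01},
    scheduler=torch.optim.lr_scheduler.StepLR,
    scheduler_kwargs={"step_size": 1, "gamma": 0.5},
)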
12 changes: 11 additions & 1 deletion flash/graph/classification/model.py

@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Mapping, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

 import torch
 from torch import nn
 from torch.nn import functional as F
 from torch.nn import Linear
+from torch.optim.lr_scheduler import _LRScheduler

 from flash.core.classification import ClassificationTask
 from flash.core.utilities.imports import _GRAPH_AVAILABLE
@@ -91,6 +92,9 @@ class GraphClassifier(ClassificationTask):
         hidden_channels: Hidden dimension sizes.
         loss_fn: Loss function for training, defaults to cross entropy.
         optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
         metrics: Metrics to compute for training and evaluation.
         learning_rate: Learning rate to use for training, defaults to `1e-3`
         model: GraphNN used, defaults to BaseGraphModel.
@@ -106,6 +110,9 @@ def __init__(
         hidden_channels: Union[List[int], int] = 512,
         loss_fn: Callable = F.cross_entropy,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         metrics: Union[Callable, Mapping, Sequence, None] = None,
         learning_rate: float = 1e-3,
         model: torch.nn.Module = None,
@@ -125,6 +132,9 @@ def __init__(
             model=model,
             loss_fn=loss_fn,
             optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
             metrics=metrics,
             learning_rate=learning_rate,
         )
3 changes: 3 additions & 0 deletions flash/image/classification/model.py

@@ -55,6 +55,9 @@ def fn_resnet(pretrained: bool = True):
             which loads the default supervised pretrained weights.
         loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
         optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
         metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics`
             package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict
             containing a combination of the aforementioned. In all cases, each metric needs to have the signature
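Only docstring lines are added here, which suggests `ImageClassifier` already accepted these arguments before this commit. An illustrative call (names from `flash.image`; values are examples, not defaults):

import torch
from flash.image import ImageClassifier

# Illustrative: SGD with momentum (the documented default optimizer is SGD)
# plus a cosine annealing schedule.
classifier = ImageClassifier(
    num_classes=10,
    optimizer=torch.optim.SGD,
    optimizer_kwargs={"momentum": 0.9},
    scheduler=torch.optim.lr_scheduler.CosineAnnealingLR,
    scheduler_kwargs={"T_max": 10},
)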
10 changes: 10 additions & 0 deletions flash/image/detection/model.py

@@ -15,6 +15,7 @@

 import torch
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler

 from flash.core.adapter import AdapterTask
 from flash.core.data.process import Serializer
@@ -41,6 +42,9 @@ class ObjectDetector(AdapterTask):
         metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
             Changing this argument currently has no effect.
         optimizer: The optimizer to use for training. Can either be the actual class or the class name.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
         pretrained: Whether the model from torchvision should be loaded with it's pretrained weights.
             Has no effect for custom models.
         learning_rate: The learning rate to use for training
@@ -58,6 +62,9 @@ def __init__(
         head: Optional[str] = "retinanet",
         pretrained: bool = True,
         optimizer: Type[Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         learning_rate: float = 5e-3,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
         **kwargs: Any,
@@ -78,6 +85,9 @@ def __init__(
             adapter,
             learning_rate=learning_rate,
             optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
             serializer=serializer or Preds(),
         )
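Note the asymmetry in the annotations: the optimizer is always a class, instantiated later with the model parameters, while `scheduler` may be a class, an instance, or a string name. Passing a class plus `scheduler_kwargs` is the practical route, because a scheduler instance needs an optimizer at construction time and the optimizer does not exist until the task builds it. An illustrative call (values are examples only):

import torch
from flash.image import ObjectDetector

# Illustrative: scheduler given as a class, configured via scheduler_kwargs,
# so the task can instantiate it once the optimizer exists.
detector = ObjectDetector(
    num_classes=91,
    head="retinanet",
    optimizer=torch.optim.Adam,
    optimizer_kwargs={"eps": 1e-7},
    scheduler=torch.optim.lr_scheduler.MultiStepLR,
    scheduler_kwargs={"milestones": [8, 11]},
)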
12 changes: 11 additions & 1 deletion flash/image/embedding/model.py

@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union

 import torch
 from pytorch_lightning.utilities import rank_zero_warn
 from torch import nn
 from torch.nn import functional as F
+from torch.optim.lr_scheduler import _LRScheduler
 from torchmetrics import Accuracy, Metric

 from flash.core.data.data_source import DefaultDataKeys
@@ -42,6 +43,9 @@ class ImageEmbedder(Task):
         pretrained: Use a pretrained backbone, defaults to ``True``.
         loss_fn: Loss function for training and finetuning, defaults to :func:`torch.nn.functional.cross_entropy`
         optimizer: Optimizer to use for training and finetuning, defaults to :class:`torch.optim.SGD`.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
         metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics`
             package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict
             containing a combination of the aforementioned. In all cases, each metric needs to have the signature
@@ -61,6 +65,9 @@ def __init__(
         pretrained: bool = True,
         loss_fn: Callable = F.cross_entropy,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         metrics: Union[Metric, Callable, Mapping, Sequence, None] = (Accuracy()),
         learning_rate: float = 1e-3,
         pooling_fn: Callable = torch.max,
@@ -69,6 +76,9 @@
             model=None,
             loss_fn=loss_fn,
             optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
             metrics=metrics,
             learning_rate=learning_rate,
             preprocess=ImageClassificationPreprocess(),
10 changes: 10 additions & 0 deletions flash/image/instance_segmentation/model.py

@@ -15,6 +15,7 @@

 import torch
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler

 from flash.core.adapter import AdapterTask
 from flash.core.data.process import Serializer
@@ -41,6 +42,9 @@ class InstanceSegmentation(AdapterTask):
         metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
             Changing this argument currently has no effect.
         optimizer: The optimizer to use for training. Can either be the actual class or the class name.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
         pretrained: Whether the model from torchvision should be loaded with it's pretrained weights.
             Has no effect for custom models.
         learning_rate: The learning rate to use for training
@@ -58,6 +62,9 @@ def __init__(
         head: Optional[str] = "mask_rcnn",
         pretrained: bool = True,
         optimizer: Type[Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         learning_rate: float = 5e-4,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
         **kwargs: Any,
@@ -78,6 +85,9 @@ def __init__(
             adapter,
             learning_rate=learning_rate,
             optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
             serializer=serializer or Preds(),
         )
10 changes: 10 additions & 0 deletions flash/image/keypoint_detection/model.py

@@ -15,6 +15,7 @@

 import torch
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler

 from flash.core.adapter import AdapterTask
 from flash.core.data.process import Serializer
@@ -41,6 +42,9 @@ class KeypointDetector(AdapterTask):
         metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
             Changing this argument currently has no effect.
         optimizer: The optimizer to use for training. Can either be the actual class or the class name.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
         pretrained: Whether the model from torchvision should be loaded with it's pretrained weights.
             Has no effect for custom models.
         learning_rate: The learning rate to use for training
@@ -59,6 +63,9 @@ def __init__(
         head: Optional[str] = "keypoint_rcnn",
         pretrained: bool = True,
         optimizer: Type[Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         learning_rate: float = 5e-4,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
         **kwargs: Any,
@@ -80,6 +87,9 @@ def __init__(
             adapter,
             learning_rate=learning_rate,
             optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
             serializer=serializer or Preds(),
         )
10 changes: 10 additions & 0 deletions flash/image/segmentation/model.py

@@ -16,6 +16,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from torch.optim.lr_scheduler import _LRScheduler
 from torchmetrics import IoU, Metric

 from flash.core.classification import ClassificationTask
@@ -53,6 +54,9 @@ class SemanticSegmentation(ClassificationTask):
         pretrained: Use a pretrained backbone.
         loss_fn: Loss function for training.
         optimizer: Optimizer to use for training.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
         metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics`
             package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict
             containing a combination of the aforementioned. In all cases, each metric needs to have the signature
@@ -80,6 +84,9 @@ def __init__(
         pretrained: Union[bool, str] = True,
         loss_fn: Optional[Callable] = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
         learning_rate: float = 1e-3,
         multi_label: bool = False,
@@ -100,6 +107,9 @@ def __init__(
             model=None,
             loss_fn=loss_fn,
             optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
             metrics=metrics,
             learning_rate=learning_rate,
             serializer=serializer or SegmentationLabels(),
7 changes: 7 additions & 0 deletions flash/tabular/classification/model.py

@@ -15,6 +15,7 @@

 import torch
 from torch.nn import functional as F
+from torch.optim.lr_scheduler import _LRScheduler
 from torchmetrics import Metric

 from flash.core.classification import ClassificationTask, Probabilities
@@ -56,6 +57,9 @@ def __init__(
         embedding_sizes: List[Tuple[int, int]] = None,
         loss_fn: Callable = F.cross_entropy,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
         learning_rate: float = 1e-2,
         multi_label: bool = False,
@@ -78,6 +82,9 @@ def __init__(
             model=model,
             loss_fn=loss_fn,
             optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
             metrics=metrics,
             learning_rate=learning_rate,
             multi_label=multi_label,
10 changes: 10 additions & 0 deletions flash/text/classification/model.py

@@ -17,6 +17,7 @@

 import torch
 from pytorch_lightning import Callback
+from torch.optim.lr_scheduler import _LRScheduler
 from torchmetrics import Metric

 from flash.core.classification import ClassificationTask, Labels
@@ -38,6 +39,9 @@ class TextClassifier(ClassificationTask):
         num_classes: Number of classes to classify.
         backbone: A model to use to compute text features can be any BERT model from HuggingFace/transformers.
         optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
         metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics`
             package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict
             containing a combination of the aforementioned. In all cases, each metric needs to have the signature
@@ -56,6 +60,9 @@ def __init__(
         backbone: str = "prajjwal1/bert-medium",
         loss_fn: Optional[Callable] = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
         learning_rate: float = 1e-2,
         multi_label: bool = False,
@@ -75,6 +82,9 @@ def __init__(
             model=None,
             loss_fn=loss_fn,
             optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
             metrics=metrics,
             learning_rate=learning_rate,
             multi_label=multi_label,
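The `str` branch of the `scheduler` union is most useful for text tasks, where warmup schedules are common. The set of accepted names is resolved by the base `Task`, which this diff does not show, so the scheduler name and its kwargs below are assumptions for illustration:

import torch
from flash.text import TextClassifier

# The scheduler name is assumed; valid strings come from the base Task's
# scheduler registry, which is not part of this diff.
classifier = TextClassifier(
    num_classes=2,
    backbone="prajjwal1/bert-medium",
    optimizer=torch.optim.AdamW,
    optimizer_kwargs={"weight_decay": 0.01},
    scheduler="linear_schedule_with_warmup",
    scheduler_kwargs={"num_warmup_steps": 0.1},
)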
19 changes: 17 additions & 2 deletions flash/text/question_answering/model.py

@@ -19,13 +19,14 @@
 import collections
 import os
 import warnings
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

 import numpy as np
 import torch
 from pytorch_lightning import Callback
 from pytorch_lightning.utilities import rank_zero_info
 from torch import Tensor
+from torch.optim.lr_scheduler import _LRScheduler
 from torchmetrics import Metric

 from flash.core.data.data_source import DefaultDataKeys
@@ -58,6 +59,9 @@ class QuestionAnsweringTask(Task):
         backbone: backbone model to use for the task.
         loss_fn: Loss function for training.
         optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
         metrics: Metrics to compute for training and evaluation. Defauls to calculating the ROUGE metric.
             Changing this argument currently has no effect.
         learning_rate: Learning rate to use for training, defaults to `3e-4`
@@ -80,6 +84,9 @@ def __init__(
         backbone: str = "distilbert-base-uncased",
         loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
         learning_rate: float = 5e-5,
         enable_ort: bool = False,
@@ -95,7 +102,15 @@ def __init__(
         warnings.simplefilter("ignore")
         # set os environ variable for multiprocesses
         os.environ["PYTHONWARNINGS"] = "ignore"
-        super().__init__(loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate)
+        super().__init__(
+            loss_fn=loss_fn,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
+            metrics=metrics,
+            learning_rate=learning_rate,
+        )
         self.model = AutoModelForQuestionAnswering.from_pretrained(backbone)
         self.enable_ort = enable_ort
         self.n_best_size = n_best_size
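Besides the three pass-through arguments, this file also reflows the previously single-line `super().__init__(...)` call into multi-line form, which accounts for the extra changed lines. Usage mirrors the other tasks (illustrative values):

import torch
from flash.text import QuestionAnsweringTask

# Illustrative: fine-tune with AdamW and a step decay; the backbone default
# comes from the signature above.
task = QuestionAnsweringTask(
    backbone="distilbert-base-uncased",
    optimizer=torch.optim.AdamW,
    optimizer_kwargs={"weight_decay": 0.01},
    scheduler=torch.optim.lr_scheduler.StepLR,
    scheduler_kwargs={"step_size": 1},
)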