This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Merge branch 'master' into fix/multilabel
tchaton authored Apr 21, 2021
2 parents 95c5562 + 79f271e commit 6d3d850
Showing 21 changed files with 203 additions and 40 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Switch to use `torchmetrics` ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169))

- Better support for `optimizer` and `schedulers` ([#232](https://github.com/PyTorchLightning/lightning-flash/pull/232))



### Fixed

Expand All @@ -28,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `RetinaNet` & `backbones` to `ObjectDetector` Task ([#121](https://github.com/PyTorchLightning/lightning-flash/pull/121))
- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116),
- Added .csv image loading utils ([#116](https://github.com/PyTorchLightning/lightning-flash/pull/116),
[#117](https://github.com/PyTorchLightning/lightning-flash/pull/117),
[#118](https://github.com/PyTorchLightning/lightning-flash/pull/118))

Expand Down
10 changes: 5 additions & 5 deletions README.md
Expand Up @@ -120,7 +120,7 @@ First, finetune:
```python
# import our libraries
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
Expand Down Expand Up @@ -170,7 +170,7 @@ Flash has an Image embedding task to encodes an image into a vector of image fea
<summary>View example</summary>

```python
from flash.core.data import download_data
from flash.data.utils import download_data
from flash.vision import ImageEmbedder

# 1. Download the data
Expand All @@ -197,7 +197,7 @@ Flash has a Summarization task to sum up text from a larger article into a short
```python
# import our libraries
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.text import SummarizationData, SummarizationTask

# 1. Download the data
Expand Down Expand Up @@ -244,7 +244,7 @@ To illustrate, say we want to build a model to predict if a passenger survived o
# import our libraries
from torchmetrics.classification import Accuracy, Precision, Recall
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.tabular import TabularClassifier, TabularData

# 1. Download the data
Expand Down Expand Up @@ -295,7 +295,7 @@ To illustrate, say we want to build a model on a tiny coco dataset.
```python
# import our libraries
import flash
from flash.core.data import download_data
from flash.data.utils import download_data
from flash.vision import ObjectDetectionData, ObjectDetector

# 1. Download the data
Expand Down
2 changes: 1 addition & 1 deletion docs/source/general/finetuning.rst
Expand Up @@ -52,7 +52,7 @@ Here are the steps in code
.. code-block:: python
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier
# 1. download and organize the data
Expand Down
8 changes: 3 additions & 5 deletions docs/source/general/predictions.rst
Expand Up @@ -12,11 +12,11 @@ Predict on a single sample of data

You can pass in a sample of data (image file path, a string of text, etc) to the :func:`~flash.core.model.Task.predict` method.


.. code-block:: python
from flash import Trainer
from flash.core.data import download_data
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier
Expand All @@ -37,7 +37,7 @@ Predict on a csv file

.. code-block:: python
from flash.core.data import download_data
from flash.data.utils import download_data
from flash.tabular import TabularClassifier
# 1. Download the data
Expand All @@ -51,5 +51,3 @@ Predict on a csv file
# 3. Generate predictions from a csv file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)
6 changes: 3 additions & 3 deletions docs/source/general/training.rst
Expand Up @@ -11,7 +11,7 @@ Some Flash tasks have been pretrained on large data sets. To accelerate your tra
.. code-block:: python
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier
# 1. download and organize the data
Expand Down Expand Up @@ -48,7 +48,7 @@ Flash tasks supports many advanced training functionalities out-of-the-box, such
# train on 1 GPU
flash.Trainer(gpus=1)
* Training on multiple GPUs

.. code-block:: python
Expand All @@ -60,7 +60,7 @@ Flash tasks supports many advanced training functionalities out-of-the-box, such
# train on gpu 1, 3, 5 (3 gpus total)
flash.Trainer(gpus=[1, 3, 5])
* Using mixed precision training

.. code-block:: python
Expand Down
6 changes: 3 additions & 3 deletions docs/source/quickstart.rst
Expand Up @@ -15,7 +15,7 @@ For getting started with Deep Learning

Easy to learn
^^^^^^^^^^^^^
If you are just getting started with deep learning, Flash offers common deep learning tasks you can use out-of-the-box in a few lines of code, no math, fancy nn.Modules or research experience required!
If you are just getting started with deep learning, Flash offers common deep learning tasks you can use out-of-the-box in a few lines of code, no math, fancy nn.Modules or research experience required!

Easy to scale
^^^^^^^^^^^^^
Expand Down Expand Up @@ -70,7 +70,7 @@ You can install flash using pip or conda:
Tasks
=====

Flash is comprised of a collection of Tasks. The Flash tasks are laser-focused objects designed to solve a well-defined type of problem, using state-of-the-art methods.
Flash is comprised of a collection of Tasks. The Flash tasks are laser-focused objects designed to solve a well-defined type of problem, using state-of-the-art methods.

The Flash tasks contain all the relevant information to solve the task at hand- the number of class labels you want to predict, number of columns in your dataset, as well as details on the model architecture used such as loss function, optimizers, etc.

Expand Down Expand Up @@ -137,7 +137,7 @@ Here's an example of finetuning.
.. code-block:: python
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier
# 1. Download the data
Expand Down
4 changes: 2 additions & 2 deletions docs/source/reference/image_classification.rst
Expand Up @@ -25,7 +25,7 @@ Use the :class:`~flash.vision.ImageClassifier` pretrained model for inference on
# import our libraries
from flash import Trainer
from flash import download_data
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier
# 1. Download the data
Expand Down Expand Up @@ -90,7 +90,7 @@ Now all we need is three lines of code to build to train our task!
.. code-block:: python
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier
# 1. Download the data
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/image_embedder.rst
Expand Up @@ -54,7 +54,7 @@ To tailor this image embedder to your dataset, finetune first.
.. code-block:: python
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageEmbedder
# 1. Download the data
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/object_detection.rst
Expand Up @@ -75,7 +75,7 @@ To tailor the object detector to your dataset, you would need to have it in `COC
.. code-block:: python
import flash
from flash.core.data import download_data
from flash.data.utils import download_data
from flash.vision import ObjectDetectionData, ObjectDetector
# 1. Download the data
Expand Down
4 changes: 2 additions & 2 deletions docs/source/reference/summarization.rst
Expand Up @@ -60,7 +60,7 @@ Or on a given dataset, use :class:`~flash.core.trainer.Trainer` `predict` method
# import our libraries
from flash import Trainer
from flash import download_data
from flash.data.utils import download_data
from flash.text import SummarizationData, SummarizationTask
# 1. Download data
Expand Down Expand Up @@ -104,7 +104,7 @@ All we need is three lines of code to train our model!
# import our libraries
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.text import SummarizationData, SummarizationTask
# 1. Download data
Expand Down
6 changes: 3 additions & 3 deletions docs/source/reference/tabular_classification.rst
Expand Up @@ -45,7 +45,7 @@ Next, we create the :class:`~flash.tabular.TabularClassifier` task, using the Da
.. code-block:: python
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.tabular import TabularClassifier, TabularData
from torchmetrics.classification import Accuracy, Precision, Recall
Expand Down Expand Up @@ -92,7 +92,7 @@ You can make predcitions on a pretrained model, that has already been trained fo
.. code-block:: python
from flash.core.data import download_data
from flash.data.utils import download_data
from flash.tabular import TabularClassifier
# 1. Download the data
Expand All @@ -113,7 +113,7 @@ Or you can finetune your own model and use that for prediction:
.. code-block:: python
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.tabular import TabularClassifier, TabularData
# 1. Load the data
Expand Down
4 changes: 2 additions & 2 deletions docs/source/reference/text_classification.rst
Expand Up @@ -24,7 +24,7 @@ Use the :class:`~flash.text.classification.model.TextClassifier` pretrained mode
from pytorch_lightning import Trainer
from flash import download_data
from flash.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
Expand Down Expand Up @@ -77,7 +77,7 @@ All we need is three lines of code to train our model!
.. code-block:: python
import flash
from flash.core.data import download_data
from flash.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
# 1. Download the data
Expand Down
4 changes: 2 additions & 2 deletions docs/source/reference/translation.rst
Expand Up @@ -42,7 +42,7 @@ Or on a given dataset, use :class:`~flash.core.trainer.Trainer` `predict` method
# import our libraries
from flash import Trainer
from flash import download_data
from flash.data.utils import download_data
from flash.text import TranslationData, TranslationTask
# 1. Download data
Expand Down Expand Up @@ -86,7 +86,7 @@ All we need is three lines of code to train our model! By default, we use a `mBA
# import our libraries
import flash
from flash import download_data
from flash.data.utils import download_data
from flash.text import TranslationData, TranslationTask
# 1. Download data
Expand Down
87 changes: 83 additions & 4 deletions flash/core/model.py
Expand Up @@ -21,9 +21,13 @@
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess

Expand Down Expand Up @@ -64,11 +68,16 @@ class Task(LightningModule):
postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
"""

schedulers: FlashRegistry = _SCHEDULERS_REGISTRY

def __init__(
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
preprocess: Preprocess = None,
Expand All @@ -78,7 +87,11 @@ def __init__(
if model is not None:
self.model = model
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
self.optimizer_cls = optimizer
self.optimizer = optimizer
self.scheduler = scheduler
self.optimizer_kwargs = optimizer_kwargs or {}
self.scheduler_kwargs = scheduler_kwargs or {}

self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
Expand Down Expand Up @@ -168,8 +181,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
batch = torch.stack(batch)
return self(batch)

def configure_optimizers(self) -> torch.optim.Optimizer:
return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
optimizer = self.optimizer
if not isinstance(self.optimizer, Optimizer):
self.optimizer_kwargs["lr"] = self.learning_rate
optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs)
if self.scheduler:
return [optimizer], [self._instantiate_scheduler(optimizer)]
return optimizer

def configure_finetune_callback(self) -> List[Callback]:
return []
Expand Down Expand Up @@ -323,3 +342,63 @@ def available_models(cls) -> List[str]:
if registry is None:
return []
return registry.available_keys()

@classmethod
def available_schedulers(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None)
if registry is None:
return []
return registry.available_keys()

def get_num_training_steps(self) -> int:
"""Total training steps inferred from datamodule and devices."""
if not getattr(self, "trainer", None):
raise MisconfigurationException("The LightningModule isn't attached to the trainer yet.")
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
dataset_size = self.trainer.limit_train_batches
elif isinstance(self.trainer.limit_train_batches, float):
# limit_train_batches is a percentage of batches
dataset_size = len(self.train_dataloader())
dataset_size = int(dataset_size * self.trainer.limit_train_batches)
else:
dataset_size = len(self.train_dataloader())

num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
if self.trainer.tpu_cores:
num_devices = max(num_devices, self.trainer.tpu_cores)

effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs

if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
return self.trainer.max_steps
return max_estimated_steps

def _compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> int:
if not isinstance(num_warmup_steps, float) or (num_warmup_steps > 1 or num_warmup_steps < 0):
raise MisconfigurationException(
"`num_warmup_steps` should be provided as float between 0 and 1 in `scheduler_kwargs`"
)
if isinstance(num_warmup_steps, float):
# Convert float values to percentage of training steps to use as warmup
num_warmup_steps *= num_training_steps
return round(num_warmup_steps)

def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
scheduler = self.scheduler
if isinstance(scheduler, _LRScheduler):
return scheduler
if isinstance(scheduler, str):
scheduler_fn = self.schedulers.get(self.scheduler)
num_training_steps: int = self.get_num_training_steps()
num_warmup_steps: int = self._compute_warmup(
num_training_steps=num_training_steps,
num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
)
return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
elif issubclass(scheduler, _LRScheduler):
return scheduler(optimizer, **self.scheduler_kwargs)
raise MisconfigurationException(
"scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` "
f"or a built-in scheduler in {self.available_schedulers()}"
)
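
The step and warmup arithmetic added to `Task` in this commit can be tried out without a `Trainer`. The block below is a minimal standalone sketch of that arithmetic, not the Flash implementation: the function names `estimate_training_steps` and `compute_warmup` and the parameter `warmup_fraction` are hypothetical, the trainer attributes are passed in as plain arguments, and `ValueError` stands in for `MisconfigurationException` so the example has no dependencies.

```python
from typing import Optional


def estimate_training_steps(
    num_batches: int,
    max_epochs: int,
    accumulate_grad_batches: int = 1,
    num_devices: int = 1,
    max_steps: Optional[int] = None,
) -> int:
    """Estimate optimizer steps, mirroring the arithmetic in the diff above.

    All arguments are plain values here (an assumption of this sketch);
    in the commit they are read from the attached trainer.
    """
    # One optimizer step consumes this many batches across all devices.
    effective_batch_size = accumulate_grad_batches * max(1, num_devices)
    estimated = (num_batches // effective_batch_size) * max_epochs
    # A smaller explicit max_steps takes precedence over the estimate.
    if max_steps and max_steps < estimated:
        return max_steps
    return estimated


def compute_warmup(num_training_steps: int, warmup_fraction: float) -> int:
    """Convert a warmup fraction in [0, 1] into a whole number of steps."""
    if not isinstance(warmup_fraction, float) or not 0.0 <= warmup_fraction <= 1.0:
        raise ValueError("warmup_fraction should be a float between 0 and 1")
    return round(warmup_fraction * num_training_steps)


# 100 batches, 10 epochs, accumulation of 2 on 2 devices.
total_steps = estimate_training_steps(
    num_batches=100, max_epochs=10, accumulate_grad_batches=2, num_devices=2
)
warmup_steps = compute_warmup(total_steps, 0.1)
print(total_steps, warmup_steps)
```

As in the diff, the warmup value has to be a float fraction; an integer such as `1` is rejected rather than treated as an absolute step count.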