
Commit 4729518
Merge branch 'master' into docs/fix-colab-link
2 parents: aa40285 + f12f55c

File tree: 18 files changed (+255, -180 lines)

.azure-pipelines/gpu-tests.yml (+4, -10)

@@ -3,18 +3,12 @@
 # Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more:
 # https://docs.microsoft.com/azure/devops/pipelines/languages/python
 
-trigger:
-  tags:
-    include:
-      - '*'
+schedules:
+- cron: "0 0 * * *"
+  displayName: Daily midnight build
   branches:
     include:
-      - master
-      - release/*
-      - refs/tags/*
-pr:
-  - master
-  - release/*
+    - master
 
 jobs:
 - job: pytest

docs/source/general/data.rst (+119, -116)

Large diffs are not rendered by default.

flash/core/model.py (+24)

@@ -13,9 +13,11 @@
 # limitations under the License.
 import functools
 import inspect
+from copy import deepcopy
 from importlib import import_module
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
+import pytorch_lightning as pl
 import torch
 import torchmetrics
 from pytorch_lightning import LightningModule
@@ -26,6 +28,7 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 
+import flash
 from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
 from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
 from flash.core.data.process import Postprocess, Preprocess, Serializer, SerializerMapping
@@ -34,6 +37,22 @@
 from flash.core.utilities.apply_func import get_callable_dict
 
 
+class BenchmarkConvergenceCI(Callback):
+
+    def __init__(self):
+        pl.seed_everything(42)
+        self.history = []
+
+    def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
+        self.history.append(deepcopy(trainer.callback_metrics))
+        if trainer.current_epoch == trainer.max_epochs - 1:
+            fn = getattr(pl_module, "_ci_benchmark_fn", None)
+            if fn:
+                fn(self.history)
+                if trainer.is_global_zero:
+                    print("Benchmark successful!")
+
+
 def predict_context(func: Callable) -> Callable:
     """
     This decorator is used as context manager
@@ -516,3 +535,8 @@ def _load_from_state_dict(
         super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
+
+    def configure_callbacks(self):
+        # used only for CI
+        if flash._IS_TESTING and torch.cuda.is_available():
+            return [BenchmarkConvergenceCI()]
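
For context on the hook at the end of this diff: Lightning calls LightningModule.configure_callbacks() while the Trainer is being set up and merges the returned list with the Trainer's own callbacks, so BenchmarkConvergenceCI is attached automatically whenever flash._IS_TESTING is set and a GPU is available. A minimal sketch of the mechanism (the class and callback names here are illustrative, not from the commit):

import pytorch_lightning as pl
from pytorch_lightning import Callback


class PrintMetrics(Callback):
    """Illustrative callback: prints the metrics dict after each validation run."""

    def on_validation_end(self, trainer, pl_module) -> None:
        print(dict(trainer.callback_metrics))


class MyModule(pl.LightningModule):

    def configure_callbacks(self):
        # Called by the Trainer during setup; the returned callbacks are
        # merged with any passed via Trainer(callbacks=...), so the module
        # can attach its own monitoring without user involvement.
        return [PrintMetrics()]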

flash/core/trainer.py (+5)

@@ -56,6 +56,11 @@ def __init__(self, *args, **kwargs):
         if flash._IS_TESTING:
             if torch.cuda.is_available():
                 kwargs["gpus"] = 1
+                kwargs["max_epochs"] = 3
+                kwargs["limit_train_batches"] = 1.0
+                kwargs["limit_val_batches"] = 1.0
+                kwargs["limit_test_batches"] = 1.0
+                kwargs["fast_dev_run"] = False
             else:
                 kwargs["fast_dev_run"] = True
         super().__init__(*args, **kwargs)
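
The explicit fast_dev_run=False is the important line: fast_dev_run=True truncates every loop to a single batch and disables checkpointing and early-stopping callbacks, which would defeat the convergence benchmark above. On a GPU CI machine the wrapper therefore amounts to roughly this (a sketch of the effective configuration, not exact Flash behaviour):

import flash

# What the override amounts to when FLASH_TESTING is set and a GPU is found:
trainer = flash.Trainer(
    gpus=1,
    max_epochs=3,
    limit_train_batches=1.0,  # full epochs, unlike fast_dev_run
    limit_val_batches=1.0,
    limit_test_batches=1.0,
    fast_dev_run=False,
)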

flash/image/classification/model.py (+13, -1)

@@ -12,13 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from types import FunctionType
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
+import pytorch_lightning as pl
 import torch
 import torchmetrics
+from pytorch_lightning.callbacks.base import Callback
 from torch import nn
 from torch.optim.lr_scheduler import _LRScheduler
 
+import flash
 from flash.core.classification import ClassificationTask
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.data.process import Serializer
@@ -128,3 +131,12 @@ def forward(self, x) -> torch.Tensor:
         if x.dim() == 4:
             x = x.mean(-1).mean(-1)
         return self.head(x)
+
+    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
+        """
+        This function is used only for debugging with CI.
+        """
+        if self.hparams.multi_label:
+            assert history[-1]["val_f1"] > 0.45
+        else:
+            assert history[-1]["val_accuracy"] > 0.90
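
Each history entry is a snapshot of trainer.callback_metrics captured by BenchmarkConvergenceCI.on_validation_end, so history[-1] holds the final epoch's metrics. For illustration (the values below are made up):

# Hypothetical contents of `history` after a 3-epoch CI run; real entries
# hold 0-d tensors rather than plain floats, but they compare the same way.
history = [
    {"val_accuracy": 0.61},
    {"val_accuracy": 0.84},
    {"val_accuracy": 0.93},
]
assert history[-1]["val_accuracy"] > 0.90  # only the final epoch is checked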

flash/image/detection/model.py (+13, -13)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
 from torch import nn, tensor
@@ -88,7 +88,7 @@ def __init__(
         anchor_generator: Optional[Type['AnchorGenerator']] = None,
         loss=None,
         metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None,
-        optimizer: Type[Optimizer] = torch.optim.Adam,
+        optimizer: Type[Optimizer] = torch.optim.AdamW,
         learning_rate: float = 1e-3,
         **kwargs: Any,
     ):
@@ -180,28 +180,28 @@ def validation_step(self, batch, batch_idx):
         # fasterrcnn takes only images for eval() mode
         outs = self.model(images)
         iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
-        return {"val_iou": iou}
+        self.log("val_iou", iou)
 
-    def validation_epoch_end(self, outs):
-        avg_iou = torch.stack([o["val_iou"] for o in outs]).mean()
-        logs = {"val_iou": avg_iou}
-        return {"avg_val_iou": avg_iou, "log": logs}
+    def on_validation_end(self) -> None:
+        return super().on_validation_end()
 
     def test_step(self, batch, batch_idx):
         images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
         # fasterrcnn takes only images for eval() mode
         outs = self.model(images)
         iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
-        return {"test_iou": iou}
-
-    def test_epoch_end(self, outs):
-        avg_iou = torch.stack([o["test_iou"] for o in outs]).mean()
-        logs = {"test_iou": avg_iou}
-        return {"avg_test_iou": avg_iou, "log": logs}
+        self.log("test_iou", iou)
 
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
         images = batch[DefaultDataKeys.INPUT]
         return self.model(images)
 
     def configure_finetune_callback(self):
         return [ObjectDetectionFineTuning(train_bn=True)]
+
+    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
+        """
+        This function is used only for debugging with CI.
+        """
+        # todo (tchaton) Improve convergence
+        # history[-1]["val_iou"]
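
The deleted validation_epoch_end/test_epoch_end bodies used the pre-1.0 convention of returning a {"log": ...} dict, which current Lightning ignores; self.log in the step is the supported path, averages the per-batch values over the epoch by default, and places the result in trainer.callback_metrics, where BenchmarkConvergenceCI reads it. A minimal sketch of the pattern (illustrative module and value):

import torch
import pytorch_lightning as pl


class Sketch(pl.LightningModule):

    def validation_step(self, batch, batch_idx):
        iou = torch.tensor(0.5)  # stand-in for the real IoU computation
        # In validation, self.log defaults to on_epoch=True, so the value
        # is averaged across batches and lands in trainer.callback_metrics.
        self.log("val_iou", iou)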

flash/image/segmentation/model.py (+7, -1)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 import torch
 from torch import nn
@@ -148,3 +148,9 @@ def forward(self, x) -> torch.Tensor:
             raise NotImplementedError(f"Unsupported output type: {type(out)}")
 
         return out
+
+    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
+        """
+        This function is used only for debugging with CI.
+        """
+        assert history[-1]["val_iou"] > 0.2

flash/image/segmentation/transforms.py (+2, -7)

@@ -38,11 +38,10 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
         "post_tensor_transform": nn.Sequential(
             ApplyToKeys(
                 [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
-                KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest')),
+                KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='bilinear')),
             ),
         ),
         "collate": Compose([kornia_collate, ApplyToKeys(DefaultDataKeys.TARGET, prepare_target)]),
-        "per_batch_transform_on_device": ApplyToKeys(DefaultDataKeys.INPUT, K.enhance.Normalize(0., 255.)),
     }
 
 
@@ -53,12 +52,8 @@ def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
         "post_tensor_transform": nn.Sequential(
             ApplyToKeys(
                 [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
-                KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.75)),
+                KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5)),
             ),
         ),
-        "per_batch_transform_on_device": ApplyToKeys(
-            DefaultDataKeys.INPUT,
-            K.augmentation.ColorJitter(0.4, p=0.5),
-        ),
     }
 )
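
KorniaParallelTransforms exists to apply a single transform instance to several keys (here INPUT and TARGET) with the same sampled parameters, keeping image and mask aligned under random augmentations. A standalone sketch of the underlying Kornia call (sizes and shapes are illustrative):

import torch
import kornia as K

resize = K.geometry.Resize((196, 196), interpolation='bilinear')
image = torch.rand(1, 3, 256, 256)                     # B, C, H, W
mask = torch.randint(0, 21, (1, 1, 256, 256)).float()  # segmentation labels

# The same op keeps image and mask spatially aligned; note that bilinear
# interpolation blends label values at mask boundaries, unlike 'nearest'.
resized_image, resized_mask = resize(image), resize(mask)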

flash/tabular/classification/model.py (+8, -2)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Mapping, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union
 
 import torch
 from torch.nn import functional as F
@@ -51,7 +51,7 @@ def __init__(
         loss_fn: Callable = F.cross_entropy,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
         metrics: List[Metric] = None,
-        learning_rate: float = 1e-3,
+        learning_rate: float = 1e-2,
         multi_label: bool = False,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
         **tabnet_kwargs,
@@ -106,3 +106,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
     def from_data(cls, datamodule, **kwargs) -> 'TabularClassifier':
         model = cls(datamodule.num_features, datamodule.num_classes, datamodule.emb_sizes, **kwargs)
         return model
+
+    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
+        """
+        This function is used only for debugging with CI.
+        """
+        assert history[-1]["val_accuracy"] > 0.75

flash/text/classification/data.py (+9, -8)

@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from functools import partial
-from logging import logMultiprocessing
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
+import torch
 from torch import Tensor
 
 import flash
@@ -69,7 +68,6 @@ def load_data(
         data: Tuple[str, Union[str, List[str]], Union[str, List[str]]],
         dataset: Optional[Any] = None,
         columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
-        use_full: bool = True,
     ) -> Union[Sequence[Mapping[str, Any]]]:
         csv_file, input, target = data
 
@@ -79,11 +77,14 @@
             data_files[stage] = str(csv_file)
 
         # FLASH_TESTING is set in the CI to run faster.
-        if flash._IS_TESTING and not use_full:
-            # used for debugging. Avoid processing the entire dataset # noqa E265
-            dataset_dict = DatasetDict({
-                stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
-            })
+        # FLASH_TESTING is set in the CI to run faster.
+        if flash._IS_TESTING and not torch.cuda.is_available():
+            try:
+                dataset_dict = DatasetDict({
+                    stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
+                })
+            except Exception:
+                dataset_dict = load_dataset(self.filetype, data_files=data_files)
         else:
             dataset_dict = load_dataset(self.filetype, data_files=data_files)
 
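
The split=[f'{stage}[:20]'] argument uses the Hugging Face datasets split-slicing syntax to load only the first 20 rows, which keeps the CPU-only CI path fast; the try/except falls back to loading the full file where slicing fails. The same pattern appears in flash/text/seq2seq/core/data.py below. A standalone sketch (the CSV path is hypothetical):

from datasets import load_dataset

# Load only the first 20 rows of a CSV-backed 'train' split.
subset = load_dataset("csv", data_files={"train": "train.csv"}, split="train[:20]")
print(len(subset))  # 20, provided the file has at least 20 rows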

flash/text/classification/model.py (+9, -3)

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 import warnings
-from typing import Callable, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
 
@@ -42,10 +42,10 @@ class TextClassifier(ClassificationTask):
     def __init__(
         self,
         num_classes: int,
-        backbone: str = "prajjwal1/bert-tiny",
+        backbone: str = "prajjwal1/bert-medium",
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
         metrics: Union[Callable, Mapping, Sequence, None] = None,
-        learning_rate: float = 1e-3,
+        learning_rate: float = 1e-2,
         multi_label: bool = False,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
     ):
@@ -90,3 +90,9 @@ def step(self, batch, batch_idx) -> dict:
         probs = torch.softmax(logits, 1)
         output["logs"] = {name: metric(probs, batch["labels"]) for name, metric in self.metrics.items()}
         return output
+
+    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
+        """
+        This function is used only for debugging with CI.
+        """
+        assert history[-1]["val_accuracy"] > 0.730
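
prajjwal1/bert-tiny and prajjwal1/bert-medium are small pretrained BERT variants on the Hugging Face Hub; the swap, together with the higher learning rate, presumably gives the CI run enough capacity to clear the 0.730 accuracy threshold. Loading the new default directly would look roughly like this (a sketch using the transformers library; num_labels=2 is illustrative):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-medium")
model = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-medium",
    num_labels=2,  # illustrative; Flash wires in its own num_classes
)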

flash/text/seq2seq/core/data.py (+6, -10)

@@ -88,34 +88,30 @@ def __init__(
         self.filetype = filetype
 
     def load_data(
-        self,
-        data: Any,
-        use_full: bool = True,
-        columns: List[str] = ["input_ids", "attention_mask", "labels"]
+        self, data: Any, columns: List[str] = ["input_ids", "attention_mask", "labels"]
     ) -> 'datasets.Dataset':
         file, input, target = data
         data_files = {}
         stage = self._running_stage.value
         data_files[stage] = str(file)
 
         # FLASH_TESTING is set in the CI to run faster.
-        if use_full and not flash._IS_TESTING:
-            dataset_dict = load_dataset(self.filetype, data_files=data_files)
-        else:
-            # used for debugging. Avoid processing the entire dataset # noqa E265
+        if flash._IS_TESTING and not torch.cuda.is_available():
             try:
                 dataset_dict = DatasetDict({
                     stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
                 })
-            except AssertionError:
+            except Exception:
                 dataset_dict = load_dataset(self.filetype, data_files=data_files)
+        else:
+            dataset_dict = load_dataset(self.filetype, data_files=data_files)
 
         dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input, target=target), batched=True)
         dataset_dict.set_format(columns=columns)
         return dataset_dict[stage]
 
     def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]:
-        return self.load_data(data, use_full=True, columns=["input_ids", "attention_mask"])
+        return self.load_data(data, columns=["input_ids", "attention_mask"])
 
 
 class Seq2SeqCSVDataSource(Seq2SeqFileDataSource):

flash/text/seq2seq/summarization/model.py (+7, -1)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Dict, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
 
 import pytorch_lightning as pl
 import torch
@@ -70,3 +70,9 @@ def compute_metrics(self, generated_tokens: torch.Tensor, batch: Dict, prefix: str):
         tgt_lns = self.tokenize_labels(batch["labels"])
         result = self.rouge(self._postprocess.uncollate(generated_tokens), tgt_lns)
         self.log_dict(result, on_step=False, on_epoch=True)
+
+    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
+        """
+        This function is used only for debugging with CI.
+        """
+        assert history[-1]["val_f1"] > 0.45
