[PoC] Add MetaLearning support through learn2learn (#737)
Co-authored-by: Ethan Harris <[email protected]>
tchaton and ethanwharris authored Sep 20, 2021
1 parent c2095d3 commit 991fdf0
Showing 30 changed files with 1,229 additions and 73 deletions.
4 changes: 4 additions & 0 deletions .azure-pipelines/gpu-tests.yml
@@ -59,6 +59,10 @@ jobs:
python -m coverage run --source flash -m pytest flash tests/examples/test_scripts.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
displayName: 'Testing'
- bash: |
bash tests/special_tests.sh
displayName: 'Testing: special'
- bash: |
python -m coverage report
python -m coverage xml
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -8,12 +8,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for the `learn2learn` `training_strategy` for `ImageClassifier` ([#737](https://github.com/PyTorchLightning/lightning-flash/pull/737))

### Changed

- Changed the default `num_workers` on Linux to `0` (matching the default for other OSes) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))

### Fixed


## [0.5.0] - 2021-09-07

### Added
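For context, a hypothetical usage sketch of the feature named in the changelog entry above. The strategy name `"prototypicalnetworks"` and the `training_strategy_kwargs` shown here are assumptions for illustration only; they are not taken from this diff.

```python
# Hypothetical sketch: selecting a learn2learn-backed training strategy on the
# ImageClassifier. The strategy name and kwargs below are assumed for
# illustration and may not match the final API.
from flash.image import ImageClassifier

model = ImageClassifier(
    backbone="resnet18",
    num_classes=10,
    training_strategy="prototypicalnetworks",          # assumed strategy name
    training_strategy_kwargs={"ways": 5, "shots": 1},  # assumed episode settings
)
```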
28 changes: 22 additions & 6 deletions flash/core/adapter.py
@@ -14,6 +14,7 @@
from abc import abstractmethod
from typing import Any, Callable, Optional

import torch.jit
from torch import nn
from torch.utils.data import DataLoader, Sampler

@@ -59,6 +60,10 @@ def test_epoch_end(self, outputs) -> None:
pass


def identity_collate_fn(x):
return x


class AdapterTask(Task):
"""The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter`
and forwards all of the hooks.
@@ -73,11 +78,12 @@ def __init__(self, adapter: Adapter, **kwargs):

self.adapter = adapter

@torch.jit.unused
@property
def backbone(self) -> nn.Module:
return self.adapter.backbone

def forward(self, x: Any) -> Any:
def forward(self, x: torch.Tensor) -> Any:
return self.adapter.forward(x)

def training_step(self, batch: Any, batch_idx: int) -> Any:
@@ -104,6 +110,7 @@ def test_epoch_end(self, outputs) -> None:
def process_train_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -113,12 +120,13 @@ def process_train_dataset(
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_train_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_val_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -128,12 +136,13 @@ def __init__(self, adapter: Adapter, **kwargs):
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_val_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_test_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -143,7 +152,7 @@ def __init__(self, adapter: Adapter, **kwargs):
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_test_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_predict_dataset(
@@ -152,11 +161,18 @@ def __init__(self, adapter: Adapter, **kwargs):
batch_size: int = 1,
num_workers: int = 0,
pin_memory: bool = False,
collate_fn: Callable = lambda x: x,
collate_fn: Callable = identity_collate_fn,
shuffle: bool = False,
drop_last: bool = True,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_predict_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn,
shuffle=shuffle,
drop_last=drop_last,
sampler=sampler,
)
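The hunks above thread a `trainer` argument from the `AdapterTask` down to the wrapped `Adapter`'s dataset-processing hooks (and swap the un-picklable `lambda x: x` default for a named `identity_collate_fn`). A minimal sketch of an adapter consuming the new argument follows; the `EpisodicAdapter` class and its logic are hypothetical and only illustrate the new signature.

```python
# Hypothetical sketch (not part of this commit): an Adapter that receives the
# newly threaded-through `trainer` argument in its dataset-processing hooks.
# Other hooks required by Adapter are omitted for brevity.
from typing import Callable, Optional

from torch.utils.data import DataLoader, Sampler

import flash
from flash.core.adapter import Adapter
from flash.core.data.auto_dataset import BaseAutoDataset


class EpisodicAdapter(Adapter):  # hypothetical adapter for illustration
    def process_train_dataset(
        self,
        dataset: BaseAutoDataset,
        trainer: "flash.Trainer",
        batch_size: int,
        num_workers: int,
        pin_memory: bool,
        collate_fn: Callable,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler: Optional[Sampler] = None,
    ) -> DataLoader:
        # A meta-learning adapter can use run-level information from the
        # trainer (e.g. `trainer.max_epochs`) to size its episodic task
        # stream; this sketch simply builds a plain DataLoader.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
            shuffle=shuffle,
            drop_last=drop_last,
            sampler=sampler,
        )
```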
58 changes: 48 additions & 10 deletions flash/core/classification.py
@@ -18,6 +18,7 @@
import torchmetrics
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.adapter import AdapterTask
from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.data.process import Serializer
from flash.core.model import Task
@@ -37,7 +38,29 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.
return F.binary_cross_entropy_with_logits(x, y.float())


class ClassificationTask(Task):
class ClassificationMixin:
def _build(
self,
num_classes: Optional[int] = None,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
multi_label: bool = False,
):
if metrics is None:
metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()

if loss_fn is None:
loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy

return metrics, loss_fn

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return torch.sigmoid(x)
return torch.softmax(x, dim=1)


class ClassificationTask(Task, ClassificationMixin):
def __init__(
self,
*args,
@@ -48,11 +71,9 @@ def __init__(
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs,
) -> None:
if metrics is None:
metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()

if loss_fn is None:
loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)

super().__init__(
*args,
loss_fn=loss_fn,
@@ -61,11 +82,28 @@ def __init__(
**kwargs,
)

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return torch.sigmoid(x)
# we'll assume that the data always comes as `(B, C, ...)`
return torch.softmax(x, dim=1)

class ClassificationAdapterTask(AdapterTask, ClassificationMixin):
def __init__(
self,
*args,
num_classes: Optional[int] = None,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs,
) -> None:

metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)

super().__init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
serializer=serializer or Classes(multi_label=multi_label),
**kwargs,
)


class ClassificationSerializer(Serializer):
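The refactor above moves the default metric and loss selection into a shared `ClassificationMixin._build`, so that both `ClassificationTask` and the new adapter-based `ClassificationAdapterTask` resolve the same defaults. A small sketch of what the helper resolves; calling the private method directly is purely for illustration.

```python
# Illustrative only: calling the shared helper directly to show the defaults it
# resolves for single-label vs. multi-label classification.
from flash.core.classification import ClassificationMixin

mixin = ClassificationMixin()

metrics, loss_fn = mixin._build(num_classes=5, multi_label=False)
# -> torchmetrics.Accuracy() and F.cross_entropy

metrics, loss_fn = mixin._build(num_classes=5, multi_label=True)
# -> torchmetrics.F1(5) and binary_cross_entropy_with_logits
```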
3 changes: 3 additions & 0 deletions flash/core/data/data_module.py
@@ -298,6 +298,7 @@ def _train_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_train_dataset(
train_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
@@ -326,6 +327,7 @@ def _val_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_val_dataset(
val_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
@@ -348,6 +350,7 @@ def _test_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_test_dataset(
test_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
2 changes: 1 addition & 1 deletion flash/core/data/data_pipeline.py
@@ -534,7 +534,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin
if isinstance(dl_args["collate_fn"], _Preprocessor):
dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn

if isinstance(dl_args["dataset"], IterableAutoDataset):
if isinstance(dl_args["dataset"], (IterableAutoDataset, IterableDataset)):
del dl_args["sampler"]

del dl_args["batch_sampler"]
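The one-line change above broadens the check so that plain `IterableDataset` instances (such as the iterable task datasets produced by meta-learning wrappers) are treated like `IterableAutoDataset` when preprocessing is detached. A standalone sketch of the idea; the surrounding `dl_args` handling is an assumption based on the visible hunk.

```python
# Standalone sketch (assumed context): sampler-related arguments must be
# dropped before rebuilding a DataLoader around an iterable-style dataset,
# because DataLoader rejects explicit samplers for IterableDataset inputs.
from torch.utils.data import IterableDataset


def strip_sampler_args(dl_args: dict) -> dict:
    if isinstance(dl_args["dataset"], IterableDataset):
        dl_args.pop("sampler", None)
    dl_args.pop("batch_sampler", None)  # rebuilt later regardless (assumption)
    return dl_args
```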
13 changes: 13 additions & 0 deletions flash/core/data/data_source.py
@@ -464,6 +464,9 @@ def load_data(

data = make_dataset(data, class_to_idx, extensions=self.extensions)
return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data]
elif dataset is not None:
dataset.num_classes = len(np.unique(data[1]))

return list(
filter(
lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions),
@@ -622,6 +625,16 @@ class TensorDataSource(SequenceDataSource[torch.Tensor]):
"""The ``TensorDataSource`` is a ``SequenceDataSource`` which expects the input to
:meth:`~flash.core.data.data_source.DataSource.load_data` to be a sequence of ``torch.Tensor`` objects."""

def load_data(
self,
data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]],
dataset: Optional[Any] = None,
) -> Sequence[Mapping[str, Any]]:
# TODO: Bring back the code to work out how many classes there are
if len(data) == 2:
dataset.num_classes = len(torch.unique(torch.tensor(data[1])))
return super().load_data(data, dataset)


class NumpyDataSource(SequenceDataSource[np.ndarray]):
"""The ``NumpyDataSource`` is a ``SequenceDataSource`` which expects the input to
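The new `TensorDataSource.load_data` above infers `num_classes` from the unique target values when an `(inputs, targets)` pair is provided. In isolation, the inference amounts to the following; the example targets are illustrative.

```python
# Minimal illustration of the num_classes inference used in the hunk above.
import torch

targets = [0, 2, 1, 2, 0]
num_classes = len(torch.unique(torch.tensor(targets)))
assert num_classes == 3
```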
11 changes: 8 additions & 3 deletions flash/core/data/process.py
@@ -342,17 +342,22 @@ def default_transforms() -> Optional[Dict[str, Callable]]:
"""
return None

def _apply_sample_transform(self, sample: Any) -> Any:
if isinstance(sample, list):
return [self.current_transform(s) for s in sample]
return self.current_transform(sample)

def pre_tensor_transform(self, sample: Any) -> Any:
"""Transforms to apply on a single object."""
return self.current_transform(sample)
return self._apply_sample_transform(sample)

def to_tensor_transform(self, sample: Any) -> Tensor:
"""Transforms to convert single object to a tensor."""
return self.current_transform(sample)
return self._apply_sample_transform(sample)

def post_tensor_transform(self, sample: Tensor) -> Tensor:
"""Transforms to apply on a tensor."""
return self.current_transform(sample)
return self._apply_sample_transform(sample)

def per_batch_transform(self, batch: Any) -> Any:
"""Transforms to apply to a whole batch (if possible use this for efficiency).
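The `_apply_sample_transform` helper added above lets the per-sample transform hooks accept either a single sample or a list of samples (as produced by episodic/meta-learning task sampling). A self-contained sketch of the same dispatch:

```python
# Self-contained sketch mirroring _apply_sample_transform: apply a transform to
# one sample, or map it over a list of samples.
from typing import Any, Callable, List, Union


def apply_sample_transform(transform: Callable, sample: Union[Any, List[Any]]) -> Any:
    if isinstance(sample, list):
        return [transform(s) for s in sample]
    return transform(sample)


assert apply_sample_transform(lambda x: x + 1, 3) == 4
assert apply_sample_transform(lambda x: x + 1, [1, 2]) == [2, 3]
```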
2 changes: 2 additions & 0 deletions flash/core/data/transforms.py
@@ -111,6 +111,8 @@ def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
This function removes that dimension and then
applies ``torch.utils.data._utils.collate.default_collate``.
"""
if len(samples) == 1 and isinstance(samples[0], list):
samples = samples[0]
for sample in samples:
for key in sample.keys():
if torch.is_tensor(sample[key]) and sample[key].ndim == 4:
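With the two added lines, `kornia_collate` first unwraps a batch that arrives as a single list of sample dicts (for example, one meta-learning task) before removing the extra Kornia dimension and collating. A usage sketch under that assumption; the keys, shapes, and expected output are illustrative.

```python
# Usage sketch (keys and shapes are illustrative assumptions): a singleton list
# of sample dicts is unwrapped before collation.
import torch

from flash.core.data.transforms import kornia_collate

samples = [
    {"input": torch.rand(1, 3, 8, 8), "target": torch.tensor(0)},
    {"input": torch.rand(1, 3, 8, 8), "target": torch.tensor(1)},
]

batch = kornia_collate([samples])  # the outer singleton list is unwrapped first
print(batch["input"].shape)        # expected: torch.Size([2, 3, 8, 8])
```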
4 changes: 4 additions & 0 deletions flash/core/integrations/icevision/adapter.py
@@ -16,6 +16,7 @@

from torch.utils.data import DataLoader, Sampler

import flash
from flash.core.adapter import Adapter
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.data.data_source import DefaultDataKeys
@@ -91,6 +92,7 @@ def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] =
def process_train_dataset(
self,
dataset: BaseAutoDataset,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -114,6 +116,7 @@ def process_train_dataset(
def process_val_dataset(
self,
dataset: BaseAutoDataset,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -137,6 +140,7 @@ def process_test_dataset(
def process_test_dataset(
self,
dataset: BaseAutoDataset,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,