Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[PoC] Add MetaLearning support through learn2learn #737

Merged
merged 63 commits into from
Sep 20, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
4272598
update
tchaton Sep 6, 2021
73ec02e
update
tchaton Sep 7, 2021
986dfe0
update
tchaton Sep 7, 2021
17267a8
update
tchaton Sep 7, 2021
a3364a5
Merge branch 'master' into learn2learn
tchaton Sep 7, 2021
b33beb7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2021
1bc4298
wip
tchaton Sep 8, 2021
70ff518
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 8, 2021
1af0544
update
tchaton Sep 8, 2021
084721c
Merge branch 'master' into learn2learn
tchaton Sep 8, 2021
c9c3a21
update imports
tchaton Sep 8, 2021
3b762f2
simplification
tchaton Sep 8, 2021
12d2668
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 8, 2021
529d462
wip
tchaton Sep 8, 2021
5e202c4
update
tchaton Sep 8, 2021
73e4aa8
Fix JIT issues
ethanwharris Sep 8, 2021
593e0c9
Fix test
ethanwharris Sep 8, 2021
004e399
add ddp test
tchaton Sep 8, 2021
a65d23f
update
tchaton Sep 8, 2021
84bed01
test
tchaton Sep 8, 2021
3b6d919
update
tchaton Sep 8, 2021
38d5eee
add persistant workers
tchaton Sep 8, 2021
fffbaa6
update
tchaton Sep 8, 2021
2d819ca
update changelog
tchaton Sep 8, 2021
7e51199
update
tchaton Sep 8, 2021
2580063
Update flash_examples/image_classification.py
tchaton Sep 8, 2021
eaf8dfc
Update flash_examples/image_classification_meta_learning.py
tchaton Sep 8, 2021
9097697
repair the sampling
tchaton Sep 10, 2021
62476d5
update
tchaton Sep 10, 2021
3991f12
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 10, 2021
4a550e4
update
tchaton Sep 10, 2021
c5b5940
update
tchaton Sep 10, 2021
40d0dca
update
tchaton Sep 10, 2021
dd0cb79
update
tchaton Sep 10, 2021
7dc1d34
update
tchaton Sep 10, 2021
482f576
update
tchaton Sep 10, 2021
72e1cb7
update
tchaton Sep 10, 2021
afc6219
update
tchaton Sep 11, 2021
cd63701
update
tchaton Sep 11, 2021
8dff389
update
tchaton Sep 12, 2021
ce54995
update
tchaton Sep 12, 2021
d0bd09c
update
tchaton Sep 12, 2021
dad26cc
Merge branch 'master' into learn2learn
tchaton Sep 13, 2021
a802ec7
Update CHANGELOG.md
ethanwharris Sep 13, 2021
43d201f
Update CHANGELOG.md
ethanwharris Sep 13, 2021
47dace8
update
tchaton Sep 14, 2021
baa1fcf
update example
tchaton Sep 20, 2021
8280807
Merge branch 'master' into learn2learn
tchaton Sep 20, 2021
01b2049
update
tchaton Sep 20, 2021
6928d10
update on comments
tchaton Sep 20, 2021
40c827d
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 20, 2021
c040b5c
update
tchaton Sep 20, 2021
e95d565
update
tchaton Sep 20, 2021
a903cd2
remove typing
tchaton Sep 20, 2021
cd91b53
update
tchaton Sep 20, 2021
7376197
Update gpu-tests.yml
ethanwharris Sep 20, 2021
d2e22ec
update
tchaton Sep 20, 2021
9b2b2bf
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 20, 2021
1c8660f
Apply suggestions from code review
ethanwharris Sep 20, 2021
3b28379
resolve test
tchaton Sep 20, 2021
65d3b06
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 20, 2021
a06546a
Merge branch 'master' into learn2learn
mergify[bot] Sep 20, 2021
fd2cce5
update
tchaton Sep 20, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torchmetrics
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.adapter import AdapterTask
from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.data.process import Serializer
from flash.core.model import Task
Expand Down Expand Up @@ -68,6 +69,37 @@ def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
return torch.softmax(x, dim=1)


class ClassificationAdapterTask(AdapterTask):
def __init__(
self,
*args,
num_classes: Optional[int] = None,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs,
) -> None:
if metrics is None:
metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()

if loss_fn is None:
loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
super().__init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
serializer=serializer or Classes(multi_label=multi_label),
**kwargs,
)

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return torch.sigmoid(x)
# we'll assume that the data always comes as `(B, C, ...)`
return torch.softmax(x, dim=1)


class ClassificationSerializer(Serializer):
"""A base class for classification serializers.

Expand Down
3 changes: 3 additions & 0 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,9 @@ def load_data(

data = make_dataset(data, class_to_idx, extensions=self.extensions)
return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data]
elif dataset is not None:
dataset.num_classes = len(np.unique(data[1]))

return list(
filter(
lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions),
Expand Down
6 changes: 6 additions & 0 deletions flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,20 @@ def default_transforms() -> Optional[Dict[str, Callable]]:

def pre_tensor_transform(self, sample: Any) -> Any:
"""Transforms to apply on a single object."""
if isinstance(sample, list):
return [self.current_transform(s) for s in sample]
return self.current_transform(sample)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def to_tensor_transform(self, sample: Any) -> Tensor:
"""Transforms to convert single object to a tensor."""
if isinstance(sample, list):
return [self.current_transform(s) for s in sample]
return self.current_transform(sample)

def post_tensor_transform(self, sample: Tensor) -> Tensor:
"""Transforms to apply on a tensor."""
if isinstance(sample, list):
return [self.current_transform(s) for s in sample]
return self.current_transform(sample)

def per_batch_transform(self, batch: Any) -> Any:
Expand Down
2 changes: 2 additions & 0 deletions flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
This function removes that dimension and then
applies ``torch.utils.data._utils.collate.default_collate``.
"""
if len(samples) == 1 and isinstance(samples[0], list):
samples = samples[0]
for sample in samples:
for key in sample.keys():
if torch.is_tensor(sample[key]) and sample[key].ndim == 4:
Expand Down
2 changes: 1 addition & 1 deletion flash/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def get(
"""
matches = [e for e in self.functions if key == e["name"]]
if not matches:
raise KeyError(f"Key: {key} is not in {type(self).__name__}")
raise KeyError(f"Key: {key} is not in {type(self).__name__}. Available keys: {self.available_keys()}")

if metadata:
matches = [m for m in matches if metadata.items() <= m["metadata"].items()]
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def _compare_version(package: str, op, version) -> bool:
_DATASETS_AVAILABLE = _module_available("datasets")
_ICEVISION_AVAILABLE = _module_available("icevision")
_ICEDATA_AVAILABLE = _module_available("icedata")
_LEARN2LEARN_AVAILABLE = _module_available("learn2learn")
tchaton marked this conversation as resolved.
Show resolved Hide resolved
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")
_VISSL_AVAILABLE = _module_available("vissl") and _module_available("classy_vision")

Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __str__(self):
_SEGMENTATION_MODELS = Provider(
"qubvel/segmentation_models.pytorch", "https://github.com/qubvel/segmentation_models.pytorch"
)
_LEARN2LEARN = Provider("learnables/learn2learn", "https://github.com/learnables/learn2learn")
_PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche")
_HUGGINGFACE = Provider("Hugging Face/transformers", "https://github.com/huggingface/transformers")
_FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq")
Expand Down
Loading