Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Update lightning version to v1.2 #133

Merged
merged 7 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions flash/core/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class NoFreeze(BaseFinetuning):
def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
pass

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand All @@ -42,7 +42,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo

FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback.

Override ``finetunning_function`` to put your unfreeze logic.
Override ``finetune_function`` to put your unfreeze logic.

Args:
attr_names: Name(s) of the module attributes of the model to be frozen.
Expand All @@ -62,15 +62,15 @@ def freeze_using_attr_names(self, pl_module, attr_names: List[str], train_bn: bo
attr = getattr(pl_module, attr_name, None)
if attr is None or not isinstance(attr, nn.Module):
MisconfigurationException(f"Your model must have a {attr} attribute")
self.freeze(module=attr, train_bn=train_bn)
self.freeze(modules=attr, train_bn=train_bn)

def finetunning_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
pass


class Freeze(FlashBaseFinetuning):

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand All @@ -86,7 +86,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo
super().__init__(attr_names, train_bn)
self.unfreeze_epoch = unfreeze_epoch

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(

super().__init__(attr_names, train_bn)

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand Down
2 changes: 1 addition & 1 deletion flash/vision/detection/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def __init__(self, train_bn: bool = True):

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
model = pl_module.model
self.freeze(module=model.backbone, train_bn=self.train_bn)
self.freeze(modules=model.backbone, train_bn=self.train_bn)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pytorch-lightning==1.2.0rc0 # todo: we shall align with real 1.2
pytorch-lightning>=1.2.5
torch>=1.7 # TODO: regenerate weights with lewer PT version
PyYAML>=5.1
Pillow>=7.2
Expand Down
2 changes: 1 addition & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import urllib
import urllib.request

# TorchVision hotfix https://github.com/pytorch/vision/issues/1938
opener = urllib.request.build_opener()
Expand Down
3 changes: 2 additions & 1 deletion tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def test_classification_task_predict_folder_path(tmpdir):
assert len(predictions) == 2


def test_classificationtask_trainer_predict(tmpdir):
@pytest.mark.skip("Requires DataPipeline update") # TODO
def test_classification_task_trainer_predict(tmpdir):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
task = ClassificationTask(model)
ds = DummyDataset()
Expand Down