Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add option for nested tasks #575

Merged
merged 7 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

- Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
12 changes: 12 additions & 0 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ def __init__(
self.deserializer = deserializer
self.serializer = serializer

self._children = []

def __setattr__(self, key, value):
if isinstance(value, LightningModule):
self._children.append(key)
patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"]
if isinstance(value, pl.Trainer) or key in patched_attributes:
if hasattr(self, "_children"):
for child in self._children:
setattr(getattr(self, child), key, value)
super().__setattr__(key, value)

def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
"""
The training/validation/test step. Override for custom behavior.
Expand Down
33 changes: 33 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,39 @@ def test_classificationtask_train(tmpdir: str, metrics: Any):
assert "test_nll_loss" in result[0]


def test_nested_task(tmpdir):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
train_dl = torch.utils.data.DataLoader(DummyDataset())
val_dl = torch.utils.data.DataLoader(DummyDataset())
child_task = ClassificationTask(model, loss_fn=F.nll_loss)

class Parent(ClassificationTask):
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, child):
super().__init__()

self.child = child

def training_step(self, batch, batch_idx):
return self.child.training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
return self.child.validation_step(batch, batch_idx)

def test_step(self, batch, batch_idx):
return self.child.test_step(batch, batch_idx)

def forward(self, x):
return self.child(x)

parent_task = Parent(child_task)

trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(parent_task, train_dl, val_dl)
result = trainer.test(parent_task, val_dl)
assert "test_nll_loss" in result[0]


def test_classificationtask_task_predict():
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
task = ClassificationTask(model, preprocess=DefaultPreprocess())
Expand Down