Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add option for nested tasks #575

Merged
merged 7 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for Semantic Segmentation backbones and heads from `segmentation-models.pytorch` ([#562](https://github.com/PyTorchLightning/lightning-flash/pull/562))

- Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
12 changes: 12 additions & 0 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ def __init__(
self.deserializer = deserializer
self.serializer = serializer

self._children = []

def __setattr__(self, key, value):
if isinstance(value, LightningModule):
self._children.append(key)
patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"]
if isinstance(value, pl.Trainer) or key in patched_attributes:
if hasattr(self, "_children"):
for child in self._children:
setattr(getattr(self, child), key, value)
super().__setattr__(key, value)

def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
"""
The training/validation/test step. Override for custom behavior.
Expand Down
41 changes: 41 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,32 @@ def forward(self, x):
return x * self.zeros + self.zero_one


class Parent(ClassificationTask):

def __init__(self, child):
super().__init__()

self.child = child

def training_step(self, batch, batch_idx):
return self.child.training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
return self.child.validation_step(batch, batch_idx)

def test_step(self, batch, batch_idx):
return self.child.test_step(batch, batch_idx)

def forward(self, x):
return self.child(x)


class GrandParent(Parent):

def __init__(self, child):
super().__init__(Parent(child))


# ================================


Expand All @@ -113,6 +139,21 @@ def test_classificationtask_train(tmpdir: str, metrics: Any):
assert "test_nll_loss" in result[0]


@pytest.mark.parametrize("task", [Parent, GrandParent])
def test_nested_tasks(tmpdir, task):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
train_dl = torch.utils.data.DataLoader(DummyDataset())
val_dl = torch.utils.data.DataLoader(DummyDataset())
child_task = ClassificationTask(model, loss_fn=F.nll_loss)

parent_task = task(child_task)

trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(parent_task, train_dl, val_dl)
result = trainer.test(parent_task, val_dl)
assert "test_nll_loss" in result[0]


def test_classificationtask_task_predict():
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
task = ClassificationTask(model, preprocess=DefaultPreprocess())
Expand Down