Fixes
ethanwharris committed Aug 17, 2021
1 parent b18a23a commit 65791b8
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions tests/core/test_model.py
@@ -29,6 +29,7 @@
 from torch.utils.data import DataLoader
 
 import flash
+from flash.core.adapter import Adapter
 from flash.core.classification import ClassificationTask
 from flash.core.data.process import DefaultPreprocess, Postprocess
 from flash.core.utilities.imports import _PIL_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE
@@ -125,6 +126,30 @@ def __init__(self, child):
         super().__init__(Parent(child))
 
 
+class BasicAdapter(Adapter):
+    def __init__(self, child):
+        super().__init__()
+
+        self.child = child
+
+    def training_step(self, batch, batch_idx):
+        return self.child.training_step(batch, batch_idx)
+
+    def validation_step(self, batch, batch_idx):
+        return self.child.validation_step(batch, batch_idx)
+
+    def test_step(self, batch, batch_idx):
+        return self.child.test_step(batch, batch_idx)
+
+    def forward(self, x):
+        return self.child(x)
+
+
+class AdapterParent(Parent):
+    def __init__(self, child):
+        super().__init__(BasicAdapter(child))
+
+
 # ================================
 
@@ -140,7 +165,7 @@ def test_classificationtask_train(tmpdir: str, metrics: Any):
     assert "test_nll_loss" in result[0]
 
 
-@pytest.mark.parametrize("task", [Parent, GrandParent])
+@pytest.mark.parametrize("task", [Parent, GrandParent, AdapterParent])
 def test_nested_tasks(tmpdir, task):
     model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
     train_dl = torch.utils.data.DataLoader(DummyDataset())
@@ -263,7 +288,7 @@ def test_available_backbones():
     class Foo(ImageClassifier):
         backbones = None
 
-    assert Foo.available_backbones() == []
+    assert Foo.available_backbones() == {}
 
 
 def test_optimization(tmpdir):
@@ -313,7 +338,7 @@ def test_optimization(tmpdir):
         scheduler_kwargs={"num_warmup_steps": 0.1},
         loss_fn=F.nll_loss,
     )
-    trainer = flash.Trainer(max_epochs=1, limit_train_batches=2)
+    trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count())
     ds = DummyDataset()
     trainer.fit(task, train_dataloader=DataLoader(ds))
     optimizer, scheduler = task.configure_optimizers()
@@ -334,5 +359,5 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5)
 
     task = ClassificationTask(model)
-    trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy())
+    trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count())
     trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))
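
For context on what the new AdapterParent parametrization exercises: BasicAdapter delegates training_step, validation_step, test_step, and forward straight to the wrapped child task, so test_nested_tasks runs the same training flow through one extra layer of indirection. A rough sketch of that path, assuming the DummyDataset helper and test setup defined earlier in tests/core/test_model.py (illustrative only, not part of this commit):

import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader

import flash
from flash.core.classification import ClassificationTask

# Hypothetical usage; AdapterParent and DummyDataset are test-local helpers.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
child = ClassificationTask(model, loss_fn=F.nll_loss)
task = AdapterParent(child)  # wraps child in a BasicAdapter that delegates back

trainer = flash.Trainer(fast_dev_run=True)
trainer.fit(task, train_dataloader=DataLoader(DummyDataset()))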
