From 65791b802a134cec31366bd9258680e70c83c942 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 17 Aug 2021 12:44:12 +0100
Subject: [PATCH] Fixes

---
 tests/core/test_model.py | 33 +++++++++++++++++++++++++++++----
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 66bcd6038e..e16d62e686 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -29,6 +29,7 @@
 from torch.utils.data import DataLoader
 
 import flash
+from flash.core.adapter import Adapter
 from flash.core.classification import ClassificationTask
 from flash.core.data.process import DefaultPreprocess, Postprocess
 from flash.core.utilities.imports import _PIL_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE
@@ -125,6 +126,30 @@ def __init__(self, child):
         super().__init__(Parent(child))
 
 
+class BasicAdapter(Adapter):
+    def __init__(self, child):
+        super().__init__()
+
+        self.child = child
+
+    def training_step(self, batch, batch_idx):
+        return self.child.training_step(batch, batch_idx)
+
+    def validation_step(self, batch, batch_idx):
+        return self.child.validation_step(batch, batch_idx)
+
+    def test_step(self, batch, batch_idx):
+        return self.child.test_step(batch, batch_idx)
+
+    def forward(self, x):
+        return self.child(x)
+
+
+class AdapterParent(Parent):
+    def __init__(self, child):
+        super().__init__(BasicAdapter(child))
+
+
 # ================================
 
 
@@ -140,7 +165,7 @@ def test_classificationtask_train(tmpdir: str, metrics: Any):
     assert "test_nll_loss" in result[0]
 
 
-@pytest.mark.parametrize("task", [Parent, GrandParent])
+@pytest.mark.parametrize("task", [Parent, GrandParent, AdapterParent])
 def test_nested_tasks(tmpdir, task):
     model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
     train_dl = torch.utils.data.DataLoader(DummyDataset())
@@ -263,7 +288,7 @@ def test_available_backbones():
     class Foo(ImageClassifier):
         backbones = None
 
-    assert Foo.available_backbones() == []
+    assert Foo.available_backbones() == {}
 
 
 def test_optimization(tmpdir):
@@ -313,7 +338,7 @@ def test_optimization(tmpdir):
             scheduler_kwargs={"num_warmup_steps": 0.1},
            loss_fn=F.nll_loss,
         )
-        trainer = flash.Trainer(max_epochs=1, limit_train_batches=2)
+        trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count())
         ds = DummyDataset()
         trainer.fit(task, train_dataloader=DataLoader(ds))
         optimizer, scheduler = task.configure_optimizers()
@@ -334,5 +359,5 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
             assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5)
 
     task = ClassificationTask(model)
-    trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy())
+    trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count())
     trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))