From 194c8d94f7f0b2f79038713005b83e7b924e2f46 Mon Sep 17 00:00:00 2001
From: Frédéric Branchaud-Charron
Date: Fri, 29 Oct 2021 10:55:05 -0400
Subject: [PATCH] Fix testing loop in Active Learning (#879)

Co-authored-by: fr.branchaud-charron
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton
Co-authored-by: Ethan Harris
---
 CHANGELOG.md                                  |  2 +
 .../classification/integrations/baal/data.py |  5 +--
 .../classification/integrations/baal/loop.py | 34 +++++++++++----
 .../classification/test_active_learning.py   | 42 +++++++++++++++++--
 4 files changed, 67 insertions(+), 16 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 58c5b77345..c8b177eb11 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed a bug where test metrics were not logged correctly with active learning ([#879](https://github.com/PyTorchLightning/lightning-flash/pull/879))
+
 ## [0.5.1] - 2021-10-26
 
 ### Added
diff --git a/flash/image/classification/integrations/baal/data.py b/flash/image/classification/integrations/baal/data.py
index ac0cc1f520..c0badc5c96 100644
--- a/flash/image/classification/integrations/baal/data.py
+++ b/flash/image/classification/integrations/baal/data.py
@@ -165,10 +165,7 @@ def label(self, probabilities: List[torch.Tensor] = None, indices=None):
         uncertainties = self.heuristic.get_uncertainties(torch.cat(probabilities, dim=0))
         indices = np.argsort(uncertainties)
         if self._dataset is not None:
-            unlabelled_mask = self._dataset.labelled == False  # noqa E712
-            unlabelled = self._dataset.labelled[unlabelled_mask]
-            unlabelled[indices[-self.query_size :]] = True
-            self._dataset.labelled[unlabelled_mask] = unlabelled
+            self._dataset.label(indices[-self.query_size :])
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         return self._dataset.state_dict()
diff --git a/flash/image/classification/integrations/baal/loop.py b/flash/image/classification/integrations/baal/loop.py
index 87a89d87c5..f71a4d41c0 100644
--- a/flash/image/classification/integrations/baal/loop.py
+++ b/flash/image/classification/integrations/baal/loop.py
@@ -19,7 +19,7 @@
 from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
 from pytorch_lightning.trainer.progress import Progress
-from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus
 
 import flash
 from flash.core.data.utils import _STAGES_PREFIX
@@ -83,6 +83,8 @@ def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
         if self.trainer.datamodule.has_labelled_data:
             self._reset_dataloader_for_stage(RunningStage.TRAINING)
             self._reset_dataloader_for_stage(RunningStage.VALIDATING)
+        if self.trainer.datamodule.has_test:
+            self._reset_dataloader_for_stage(RunningStage.TESTING)
         if self.trainer.datamodule.has_unlabelled_data:
             self._reset_dataloader_for_stage(RunningStage.PREDICTING)
         self.progress.increment_ready()
@@ -94,7 +96,10 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
             self.fit_loop.run()
 
         if self.trainer.datamodule.has_test:
-            self.trainer.test_loop.run()
+            self._reset_testing()
+            metrics = self.trainer.test_loop.run()
+            if metrics:
+                self.trainer.logger.log_metrics(metrics[0], step=self.trainer.global_step)
 
         if self.trainer.datamodule.has_unlabelled_data:
             self._reset_predicting()
@@ -133,6 +138,7 @@ def _reset_fitting(self):
         self.trainer.training = True
         self.trainer.lightning_module.on_train_dataloader()
         self.trainer.accelerator.connect(self._lightning_module)
+        self.fit_loop.epoch_progress = Progress()
 
     def _reset_predicting(self):
         self.trainer.state.fn = TrainerFn.PREDICTING
@@ -140,12 +146,22 @@ def _reset_predicting(self):
         self.trainer.lightning_module.on_predict_dataloader()
         self.trainer.accelerator.connect(self.inference_model)
 
+    def _reset_testing(self):
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.state.status = TrainerStatus.RUNNING
+        self.trainer.testing = True
+        self.trainer.lightning_module.on_test_dataloader()
+        self.trainer.accelerator.connect(self._lightning_module)
+
     def _reset_dataloader_for_stage(self, running_state: RunningStage):
         dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
-        setattr(
-            self.trainer.lightning_module,
-            dataloader_name,
-            _PatchDataLoader(getattr(self.trainer.datamodule, dataloader_name)(), running_state),
-        )
-        setattr(self.trainer, dataloader_name, None)
-        getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
+        # If the dataloader exists, we reset it.
+        dataloader = getattr(self.trainer.datamodule, dataloader_name, None)
+        if dataloader:
+            setattr(
+                self.trainer.lightning_module,
+                dataloader_name,
+                _PatchDataLoader(dataloader(), running_state),
+            )
+            setattr(self.trainer, dataloader_name, None)
+            getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py
index 7a56729482..725974c595 100644
--- a/tests/image/classification/test_active_learning.py
+++ b/tests/image/classification/test_active_learning.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 import pytest
+import torch
 from pytorch_lightning import seed_everything
 from torch import nn
 from torch.utils.data import SequentialSampler
@@ -94,12 +95,21 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s
         backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, serializer=Probabilities()
     )
     trainer = flash.Trainer(max_epochs=3)
-
-    active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
+    active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1, inference_iteration=3)
     active_learning_loop.connect(trainer.fit_loop)
     trainer.fit_loop = active_learning_loop
 
-    trainer.finetune(model, datamodule=active_learning_dm, strategy="freeze")
+    trainer.finetune(model, datamodule=active_learning_dm, strategy="no_freeze")
+    # Check that all metrics are logged
+    assert all(
+        any(m in log_met for log_met in active_learning_loop.trainer.logged_metrics) for m in ("train", "val", "test")
+    )
+
+    # Check that the weights have changed for both modules (classifier and MC inference stay in sync).
+    classifier = active_learning_loop._lightning_module.adapter.parameters()
+    mc_inference = active_learning_loop.inference_model.parent_module.parameters()
+    assert all(torch.equal(p1, p2) for p1, p2 in zip(classifier, mc_inference))
+
     if initial_num_labels == 0:
         assert len(active_learning_dm._dataset) == 15
     else:
@@ -117,3 +127,29 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s
     else:
         # in the second scenario we have more labelled data!
         assert len(active_learning_dm.val_dataloader()) == 5
+
+
+@pytest.mark.skipif(not (_IMAGE_TESTING and _BAAL_AVAILABLE), reason="image and baal libraries aren't installed.")
+def test_no_validation_loop(simple_datamodule):
+    active_learning_dm = ActiveLearningDataModule(
+        simple_datamodule,
+        initial_num_labels=2,
+        query_size=100,
+        val_split=0.0,
+    )
+    assert active_learning_dm.val_dataloader is None
+    head = nn.Sequential(
+        nn.Dropout(p=0.1),
+        nn.Linear(512, active_learning_dm.num_classes),
+    )
+
+    model = ImageClassifier(
+        backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, serializer=Probabilities()
+    )
+    trainer = flash.Trainer(max_epochs=3)
+    active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1, inference_iteration=3)
+    active_learning_loop.connect(trainer.fit_loop)
+    trainer.fit_loop = active_learning_loop
+
+    # Check that we can finetune without val_set
+    trainer.finetune(model, datamodule=active_learning_dm, strategy="no_freeze")
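
Usage sketch (not part of the diff): the wiring the new tests exercise, assuming `datamodule` is any Flash image-classification datamodule (the tests use the `simple_datamodule` fixture) and that both integration classes are exported from `flash.image.classification.integrations.baal`, as the file paths above suggest:

    import flash
    from torch import nn
    from flash.core.classification import Probabilities
    from flash.image import ImageClassifier
    from flash.image.classification.integrations.baal import (
        ActiveLearningDataModule,
        ActiveLearningLoop,
    )

    # Wrap a plain datamodule so the labelled/unlabelled pools are tracked
    # across active-learning iterations.
    active_learning_dm = ActiveLearningDataModule(
        datamodule,  # assumption: an ImageClassificationData instance
        initial_num_labels=2,
        query_size=100,
        val_split=0.0,
    )

    # A dropout head gives BaaL's Monte-Carlo inference something to sample over.
    head = nn.Sequential(
        nn.Dropout(p=0.1),
        nn.Linear(512, active_learning_dm.num_classes),  # 512 = resnet18 feature dim
    )
    model = ImageClassifier(
        backbone="resnet18",
        head=head,
        num_classes=active_learning_dm.num_classes,
        serializer=Probabilities(),  # the labelling heuristic consumes probabilities
    )

    # Swap the trainer's fit loop for the active-learning loop. With this patch,
    # each iteration also resets the trainer state for testing and logs the
    # test metrics via trainer.logger.log_metrics.
    trainer = flash.Trainer(max_epochs=3)
    active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1, inference_iteration=3)
    active_learning_loop.connect(trainer.fit_loop)
    trainer.fit_loop = active_learning_loop

    trainer.finetune(model, datamodule=active_learning_dm, strategy="no_freeze")

With `val_split=0.0` the wrapped datamodule exposes no validation loader, which is the case the new `test_no_validation_loop` covers.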