This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fixed a bug where validation metrics could be aggregated together with test metrics #900

Merged 3 commits on Oct 29, 2021
Changes from 2 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed a bug where validation metrics could be aggregated together with test metrics in some cases ([#900](https://github.com/PyTorchLightning/lightning-flash/pull/900))
+
 ## [0.5.1] - 2021-10-26

 ### Added
3 changes: 2 additions & 1 deletion flash/core/model.py
@@ -354,6 +354,7 @@ def __init__(

         self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
         self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
+        self.test_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
         self.learning_rate = learning_rate
         # TODO: should we save more? Bug on some regarding yaml if we save metrics
         self.save_hyperparameters("learning_rate", "optimizer")
@@ -454,7 +455,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> None:
         )

     def test_step(self, batch: Any, batch_idx: int) -> None:
-        output = self.step(batch, batch_idx, self.val_metrics)
+        output = self.step(batch, batch_idx, self.test_metrics)
         self.log_dict(
             {f"test_{k}": v for k, v in output[OutputKeys.LOGS].items()},
             on_step=False,
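Why `test_step` needs its own metric collection: the metric callables passed to a Flash task (typically torchmetrics objects) are stateful and accumulate updates across steps, so reusing `self.val_metrics` inside `test_step` folds test batches into the validation running statistics. A minimal sketch of the failure mode, using a hypothetical `RunningAccuracy` stand-in rather than the real torchmetrics API:

```python
from copy import deepcopy


class RunningAccuracy:
    """Toy stand-in for a stateful metric that accumulates state across calls."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total


# Buggy behaviour: validation and test share one instance, so the "test" result
# is polluted by the validation batches.
shared = RunningAccuracy()
shared.update([1, 1], [1, 1])  # validation step: 2/2 correct
shared.update([1, 1], [0, 0])  # test step: 0/2 correct
print(shared.compute())  # 0.5 -- neither the validation nor the test accuracy

# Fixed behaviour (mirroring the deepcopy in __init__ above): each split gets
# its own state, so the two accuracies stay separate.
val_metric = RunningAccuracy()
test_metric = deepcopy(val_metric)
val_metric.update([1, 1], [1, 1])
test_metric.update([1, 1], [0, 0])
print(val_metric.compute(), test_metric.compute())  # 1.0 0.0
```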
8 changes: 8 additions & 0 deletions tests/core/test_model.py
@@ -437,13 +437,21 @@ def i_will_create_a_misconfiguration_exception(optimizer):
 def test_classification_task_metrics():
     train_dataset = FixedDataset([0, 1])
     val_dataset = FixedDataset([1, 1])
+    test_dataset = FixedDataset([0, 0])

     model = OnesModel()

     class CheckAccuracy(Callback):
         def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5)

+        def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            assert math.isclose(trainer.callback_metrics["val_accuracy"], 1.0)
+
+        def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            assert math.isclose(trainer.callback_metrics["test_accuracy"], 0.0)
+
     task = ClassificationTask(model)
     trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count())
     trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))
+    trainer.test(task, dataloaders=DataLoader(test_dataset))
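A note on how these fixtures expose the bug (a reading of the test, not text from the PR): `OnesModel` appears to predict 1 for every sample, as the asserted values suggest, so each split gets a distinct accuracy (train `[0, 1]` → 0.5, val `[1, 1]` → 1.0, test `[0, 0]` → 0.0) and any validation state leaking into the test metric would surface as a non-zero test accuracy. A quick back-of-the-envelope check of those expected values:

```python
# Assumes OnesModel predicts 1 for every sample, as the asserted values suggest.
def accuracy_of_ones(targets):
    preds = [1] * len(targets)
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


assert accuracy_of_ones([0, 1]) == 0.5  # train_accuracy_epoch
assert accuracy_of_ones([1, 1]) == 1.0  # val_accuracy
assert accuracy_of_ones([0, 0]) == 0.0  # test_accuracy
# With the shared-metric bug, the test accuracy would instead read (2 + 0) / (2 + 2) = 0.5.
```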