From c835d8c6fd173cc68f201861a8f20060e1e07508 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 1 Mar 2022 11:43:50 +0000 Subject: [PATCH] Fix loss function buffer support (#1203) * Fix loss function buffer support * Update CHANGELOG.md --- CHANGELOG.md | 2 ++ flash/core/utilities/apply_func.py | 6 +++++- tests/core/test_model.py | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10582bb1ea..c9171d4c7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed DDP support for `VideoClassifier` ([#1189](https://github.com/PyTorchLightning/lightning-flash/pull/1189)) +- Fixed a bug where buffers in loss functions were not correctly registered in the `Task` ([#1203](https://github.com/PyTorchLightning/lightning-flash/pull/1203)) + ## [0.7.0] - 2022-02-15 ### Added diff --git a/flash/core/utilities/apply_func.py b/flash/core/utilities/apply_func.py index c218b23976..2b08cecbb3 100644 --- a/flash/core/utilities/apply_func.py +++ b/flash/core/utilities/apply_func.py @@ -13,12 +13,16 @@ # limitations under the License. from typing import Callable, Dict, Mapping, Sequence, Type, Union +from torch import nn + def get_callable_name(fn_or_class: Union[Callable, object]) -> str: return getattr(fn_or_class, "__name__", fn_or_class.__class__.__name__).lower() -def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Mapping]: +def get_callable_dict(fn: Union[nn.Module, Callable, Mapping, Sequence]) -> Union[Dict, Mapping]: + if isinstance(fn, nn.Module): + return nn.ModuleDict({get_callable_name(fn): fn}) if isinstance(fn, Mapping): return fn if isinstance(fn, Sequence): diff --git a/tests/core/test_model.py b/tests/core/test_model.py index ce3f62fc44..76dddeddfa 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -488,3 +488,12 @@ def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count()) trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset)) trainer.test(task, DataLoader(test_dataset)) + + +def test_loss_fn_buffer(): + weight = torch.rand(10) + model = Task(loss_fn=nn.CrossEntropyLoss(weight=weight)) + state_dict = model.state_dict() + + assert len(state_dict) == 1 + assert torch.allclose(state_dict["loss_fn.crossentropyloss.weight"], weight)