From 2af9944d7dba122cd5e239184c3bdc987ee650f0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 11 Aug 2025 06:44:16 +0200 Subject: [PATCH 1/6] fix implementation --- src/lightning/pytorch/utilities/parsing.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 16eef555291bd..874d1d75441d8 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -167,7 +167,14 @@ def save_hyperparameters( if given_hparams is not None: init_args = given_hparams elif is_dataclass(obj): - init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} + obj_fields = fields(obj) + init_args = {f.name: getattr(obj, f.name) for f in obj_fields if f.init} + if any(not f.init for f in obj_fields): + rank_zero_warn( + "Detected a dataclass with fields with `init=False`. This is not supported by `save_hyperparameters`" + " and will not save those fields. Consider removing `init=False` and just re-initialize the attributes" + " in the `__post_init__` method of the dataclass." + ) else: init_args = {} From 0869293367edbcdb67ffa18446f200b87171b2a8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 11 Aug 2025 06:44:34 +0200 Subject: [PATCH 2/6] add testing --- tests/tests_pytorch/models/test_hparams.py | 28 +++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index f14d62b6befb4..feae5aead32cb 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -17,7 +17,7 @@ import pickle import sys from argparse import Namespace -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from unittest import mock @@ -881,6 +881,32 @@ def test_dataclass_lightning_module(tmp_path): assert model.hparams == {"mandatory": 33, "optional": "cocofruit"} +def test_dataclass_with_init_false_fields(): + """Test that save_hyperparameters() filters out fields with init=False and issues a warning.""" + + @dataclass + class DataClassWithInitFalseFieldsModel(BoringModel): + mandatory: int + optional: str = "optional" + non_init_field: int = field(default=999, init=False) + another_non_init: str = field(default="not_in_init", init=False) + + def __post_init__(self): + super().__init__() + self.save_hyperparameters() + + with pytest.warns(UserWarning, match="Detected a dataclass with fields with `init=False`"): + model = DataClassWithInitFalseFieldsModel(33, optional="cocofruit") + + expected_hparams = {"mandatory": 33, "optional": "cocofruit"} + assert model.hparams == expected_hparams + + assert model.non_init_field == 999 + assert model.another_non_init == "not_in_init" + assert "non_init_field" not in model.hparams + assert "another_non_init" not in model.hparams + + class NoHparamsModel(BoringModel): """Tests a model without hparams.""" From 057ede06ccc30e9751e2295a8ec514f1a2f2f98c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 11 Aug 2025 07:56:17 +0200 Subject: [PATCH 3/6] small change to wording --- src/lightning/pytorch/utilities/parsing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 874d1d75441d8..f07eced80c0e2 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -172,8 +172,8 @@ def save_hyperparameters( if any(not f.init for f in obj_fields): rank_zero_warn( "Detected a dataclass with fields with `init=False`. This is not supported by `save_hyperparameters`" - " and will not save those fields. Consider removing `init=False` and just re-initialize the attributes" - " in the `__post_init__` method of the dataclass." + " and will not save those fields in `self.hparams`. Consider removing `init=False` and just" + " re-initialize the attributes in the `__post_init__` method of the dataclass." ) else: init_args = {} From 82dba6f7c04f55c5a4d84165eba30bfa54619b32 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 11 Aug 2025 07:58:39 +0200 Subject: [PATCH 4/6] changelog --- src/lightning/pytorch/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 01bbc0ff03fb0..dc007aed8961d 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960)) +- Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051)) + + --- ## [2.5.2] - 2025-06-20 From 19acad92742287ded746550f5a102800d5af8558 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 11 Aug 2025 11:33:53 +0200 Subject: [PATCH 5/6] Update src/lightning/pytorch/CHANGELOG.md Co-authored-by: Quentin Soubeyran <95291314+QuentinSoubeyranAqemia@users.noreply.github.com> --- src/lightning/pytorch/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index dc007aed8961d..a2e9ded0aeb2c 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -41,7 +41,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960)) -- Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051)) +- Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` fields ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051)) --- From 4272b983c01b430e1ac6cacdd75d3cc018886d92 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 12 Aug 2025 06:36:11 +0200 Subject: [PATCH 6/6] fix based on feedback --- src/lightning/pytorch/utilities/parsing.py | 6 ------ tests/tests_pytorch/models/test_hparams.py | 3 +-- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index f07eced80c0e2..829cc7a994b93 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -169,12 +169,6 @@ def save_hyperparameters( elif is_dataclass(obj): obj_fields = fields(obj) init_args = {f.name: getattr(obj, f.name) for f in obj_fields if f.init} - if any(not f.init for f in obj_fields): - rank_zero_warn( - "Detected a dataclass with fields with `init=False`. This is not supported by `save_hyperparameters`" - " and will not save those fields in `self.hparams`. Consider removing `init=False` and just" - " re-initialize the attributes in the `__post_init__` method of the dataclass." - ) else: init_args = {} diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index feae5aead32cb..575bcadadc404 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -895,8 +895,7 @@ def __post_init__(self): super().__init__() self.save_hyperparameters() - with pytest.warns(UserWarning, match="Detected a dataclass with fields with `init=False`"): - model = DataClassWithInitFalseFieldsModel(33, optional="cocofruit") + model = DataClassWithInitFalseFieldsModel(33, optional="cocofruit") expected_hparams = {"mandatory": 33, "optional": "cocofruit"} assert model.hparams == expected_hparams