Skip to content

Commit

Permalink
Fix Enums parsing in generated hparms yaml (#9170)
Browse files Browse the repository at this point in the history
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
  • Loading branch information
4 people authored Oct 25, 2021
1 parent 0e0247a commit 47e7a28
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))


- Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))


- Added `ckpt_path` argument for `trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))


Expand All @@ -242,6 +245,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
Old [neptune-client](https://github.com/neptune-ai/neptune-client) API is supported by `NeptuneClient` from [neptune-contrib](https://github.com/neptune-ai/neptune-contrib) repo.


- Parsing of `enums` type hyperparameters to be saved in the `haprams.yaml` file by tensorboard and csv loggers has been fixed and made in line with how omegaconf parses it. ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))


- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))


Expand Down
13 changes: 9 additions & 4 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
from argparse import Namespace
from copy import deepcopy
from enum import Enum
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
from warnings import warn

Expand Down Expand Up @@ -318,8 +319,8 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
Args:
config_yaml: Path to config yaml file
use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
the hparams will be converted to `DictConfig` if possible
use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
the hparams will be converted to ``DictConfig`` if possible.
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_yaml = './testing-hparams.yaml'
Expand All @@ -346,11 +347,14 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
return hparams


def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:
"""
Args:
config_yaml: path to new YAML file
hparams: parameters to be saved
use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
the hparams will be converted to ``DictConfig`` if possible.
"""
fs = get_filesystem(config_yaml)
if not fs.isdir(os.path.dirname(config_yaml)):
Expand All @@ -363,7 +367,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
hparams = dict(hparams)

# saving with OmegaConf objects
if _OMEGACONF_AVAILABLE:
if _OMEGACONF_AVAILABLE and use_omegaconf:
# deepcopy: hparams from user shouldn't be resolved
hparams = deepcopy(hparams)
hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True)
Expand All @@ -381,6 +385,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
# drop paramaters which contain some strange datatypes as fsspec
for k, v in hparams.items():
try:
v = v.name if isinstance(v, Enum) else v
yaml.dump(v)
except TypeError:
warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
Expand Down
30 changes: 25 additions & 5 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import pickle
from argparse import Namespace
from dataclasses import dataclass
from enum import Enum
from unittest import mock

import cloudpickle
import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from omegaconf import Container, OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer
Expand Down Expand Up @@ -477,22 +479,40 @@ def test_hparams_pickle_warning(tmpdir):


def test_hparams_save_yaml(tmpdir):
class Options(str, Enum):
option1name = "option1val"
option2name = "option2val"
option3name = "option3val"

hparams = dict(
batch_size=32, learning_rate=0.001, data_root="./any/path/here", nasted=dict(any_num=123, anystr="abcd")
batch_size=32,
learning_rate=0.001,
data_root="./any/path/here",
nested=dict(any_num=123, anystr="abcd"),
switch=Options.option3name,
)
path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")

def _compare_params(loaded_params, default_params: dict):
assert isinstance(loaded_params, (dict, DictConfig))
assert loaded_params.keys() == default_params.keys()
for k, v in default_params.items():
if isinstance(v, Enum):
assert v.name == loaded_params[k]
else:
assert v == loaded_params[k]

save_hparams_to_yaml(path_yaml, hparams)
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

save_hparams_to_yaml(path_yaml, Namespace(**hparams))
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
assert load_hparams_from_yaml(path_yaml) == hparams
_compare_params(load_hparams_from_yaml(path_yaml), hparams)


class NoArgsSubClassBoringModel(CustomBoringModel):
Expand Down

0 comments on commit 47e7a28

Please sign in to comment.