Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-she authored Oct 26, 2020
2 parents 016c10a + f07ee33 commit 65e1256
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,4 @@ mlruns/
*.ckpt
pytorch\ lightning
test-reports/
wandb
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,45 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `dirpath` and `filename` parameter in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))


- Added plugins docs and DDPPlugin to customize ddp across all accelerators([#4258](https://github.com/PyTorchLightning/pytorch-lightning/pull/4285))


- Added `strict` option to the scheduler dictionary ([#3586](https://github.com/PyTorchLightning/pytorch-lightning/pull/3586))

- Added autogenerated helptext to `Trainer.add_argparse_args`. ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344))

- Added `fsspec` support for profilers ([#4162](https://github.com/PyTorchLightning/pytorch-lightning/pull/4162))


### Changed


- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))


- Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130))


- Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273))


- Fixed sanitized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))


### Deprecated


- Deprecated `filepath` in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))


- Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))


### Removed



### Fixed

- Fixed setting device ids in DDP ([#4297](https://github.com/PyTorchLightning/pytorch-lightning/pull/4297))
Expand Down
25 changes: 25 additions & 0 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,31 @@ def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:

return params

@staticmethod
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.
Args:
params: Dictionary containing the hyperparameters
Returns:
dictionary with all callables sanitized
"""
def _sanitize_callable(val):
# Give them one chance to return a value. Don't go rabbit hole of recursive call
if isinstance(val, Callable):
try:
_val = val()
if isinstance(_val, Callable):
return val.__name__
return _val
except Exception:
return val.__name__
return val

return {key: _sanitize_callable(val) for key, val in params.items()}

@staticmethod
def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
"""
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
    """Push hyperparameters into the wandb run config.

    The params are converted to a plain dict, flattened, and stripped of
    raw callables before being handed to ``wandb``.
    """
    sanitized = self._sanitize_callable_params(
        self._flatten_dict(self._convert_params(params))
    )
    self.experiment.config.update(sanitized, allow_val_change=True)

@rank_zero_only
Expand Down
29 changes: 29 additions & 0 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import os
import pickle
from unittest import mock
from argparse import ArgumentParser
import types

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
Expand Down Expand Up @@ -109,3 +111,30 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):

assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}


def test_wandb_sanitize_callable_params(tmpdir):
    """
    Callables are not serializable, so each one is given a single chance to
    return a value; when the returned type is not acceptable, the
    callable's name is logged instead.
    """
    cli_args = "--max_epochs 1".split(" ")
    arg_parser = Trainer.add_argparse_args(parent_parser=ArgumentParser())
    params = arg_parser.parse_args(cli_args)

    def return_something():
        return "something"

    def wrapper_something():
        return return_something

    params.something = return_something
    params.wrapper_something = wrapper_something

    # Trainer exposes `gpus` via a callable default from argparse.
    assert isinstance(params.gpus, types.FunctionType)

    sanitized = WandbLogger._sanitize_callable_params(
        WandbLogger._flatten_dict(WandbLogger._convert_params(params))
    )
    assert sanitized["gpus"] == '_gpus_arg_default'
    assert sanitized["something"] == "something"
    assert sanitized["wrapper_something"] == "wrapper_something"

0 comments on commit 65e1256

Please sign in to comment.