Skip to content

Commit

Permalink
Use PickleError base class to detect all pickle errors (#6917)
Browse files Browse the repository at this point in the history
* Use PickleError base class to detect all pickle errors

* Update changelog with #6917

* Add pickle test for torch ScriptModule

Co-authored-by: Ken Witham <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
  • Loading branch information
3 people authored Apr 14, 2021
1 parent 03a73b3 commit dcff503
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))


- Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))


- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


- Fixed multi-gpu join for Horovod ([#6954](https://github.com/PyTorchLightning/pytorch-lightning/pull/6954))
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def is_picklable(obj: object) -> bool:
try:
pickle.dumps(obj)
return True
except (pickle.PicklingError, AttributeError):
except (pickle.PickleError, AttributeError):
return False


Expand Down
4 changes: 3 additions & 1 deletion tests/utilities/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import pytest

from torch.jit import ScriptModule

from pytorch_lightning.utilities.parsing import (
AttributeDict,
clean_namespace,
Expand Down Expand Up @@ -203,7 +205,7 @@ class UnpicklableClass:
pass

true_cases = [None, True, 123, "str", (123, "str"), max]
false_cases = [unpicklable_function, UnpicklableClass]
false_cases = [unpicklable_function, UnpicklableClass, ScriptModule()]

for case in true_cases:
assert is_picklable(case) is True
Expand Down

0 comments on commit dcff503

Please sign in to comment.