Skip to content

Commit

Permalink
Fix reset_seed() converting the PL_SEED_WORKERS environment varia…
Browse files Browse the repository at this point in the history
…ble `str` read to `bool` (#10099)

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: tchaton <[email protected]>
  • Loading branch information
3 people authored Oct 28, 2021
1 parent 9af1dd7 commit 83d74bb
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed creation of `dirpath` in `BaseProfiler` if it doesn't exist ([#10073](https://github.com/PyTorchLightning/pytorch-lightning/pull/10073))


- Fixed an issue with `pl.utilities.seed.reset_seed` converting the `PL_SEED_WORKERS` environment variable to `bool` ([#10099](https://github.com/PyTorchLightning/pytorch-lightning/pull/10099))



## [1.4.9] - 2021-09-30

- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def reset_seed() -> None:
If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing.
"""
seed = os.environ.get("PL_GLOBAL_SEED", None)
workers = os.environ.get("PL_SEED_WORKERS", False)
workers = os.environ.get("PL_SEED_WORKERS", "0")
if seed is not None:
seed_everything(int(seed), workers=bool(workers))
seed_everything(int(seed), workers=bool(int(workers)))


def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
Expand Down
13 changes: 10 additions & 3 deletions tests/utilities/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,19 @@ def test_reset_seed_no_op():
assert "PL_GLOBAL_SEED" not in os.environ


def test_reset_seed_everything():
@pytest.mark.parametrize("workers", (True, False))
def test_reset_seed_everything(workers):
"""Test that we can reset the seed to the initial value set by seed_everything()"""
assert "PL_GLOBAL_SEED" not in os.environ
seed_utils.seed_everything(123)
assert os.environ["PL_GLOBAL_SEED"] == "123"
assert "PL_SEED_WORKERS" not in os.environ

seed_utils.seed_everything(123, workers)
before = torch.rand(1)
assert os.environ["PL_GLOBAL_SEED"] == "123"
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))

seed_utils.reset_seed()
after = torch.rand(1)
assert os.environ["PL_GLOBAL_SEED"] == "123"
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
assert torch.allclose(before, after)

0 comments on commit 83d74bb

Please sign in to comment.