From d5ae9ec568e3b02a5e28f40cac3602cf1bc03ee2 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Sat, 13 Jul 2024 02:54:04 +0530 Subject: [PATCH] Make numpy an optional dependency in `utilities\seed.py` (#20055) Co-authored-by: awaelchli --- src/lightning/fabric/CHANGELOG.md | 2 +- src/lightning/fabric/utilities/imports.py | 5 ++- src/lightning/fabric/utilities/seed.py | 44 +++++++++++++++-------- src/lightning/pytorch/CHANGELOG.md | 3 +- tests/tests_fabric/utilities/test_seed.py | 32 ++++++++++++++++- 5 files changed, 67 insertions(+), 19 deletions(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 993df019fc3bd..60277c56f0a7b 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- +- Changed the implementation of how seeds are chosen for dataloader workers when using `seed_everything(..., workers=True)` ([#20055](https://github.com/Lightning-AI/pytorch-lightning/pull/20055)) - diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index f180d40d2c4ba..4dbd57e531859 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -17,7 +17,10 @@ import platform import sys -from lightning_utilities.core.imports import compare_version +from lightning_utilities.core.imports import RequirementCache, compare_version + +_NUMPY_AVAILABLE = RequirementCache("numpy") + _IS_WINDOWS = platform.system() == "Windows" diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index b274bce88fcdf..6cc549507c584 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -3,17 +3,19 @@ import random from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import numpy as np import torch +from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE from lightning.fabric.utilities.rank_zero import _get_rank, rank_prefixed_message, rank_zero_only, rank_zero_warn log = logging.getLogger(__name__) -max_seed_value = np.iinfo(np.uint32).max -min_seed_value = np.iinfo(np.uint32).min + +max_seed_value = 4294967295 # 2^32 - 1 (uint32) +min_seed_value = 0 def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: @@ -54,7 +56,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank())) os.environ["PL_GLOBAL_SEED"] = str(seed) random.seed(seed) - np.random.seed(seed) + if _NUMPY_AVAILABLE: + np.random.seed(seed) torch.manual_seed(seed) os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" @@ -91,24 +94,34 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: log.debug( f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}" ) - ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) - # use 128 bits (4 x 32-bit words) - np.random.seed(ss.generate_state(4)) - # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module - torch_ss, stdlib_ss = ss.spawn(2) - torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0]) - # use 128 bits expressed as an integer - stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() - random.seed(stdlib_seed) + seed_sequence = _generate_seed_sequence(base_seed, worker_id, global_rank, count=4) + torch.manual_seed(seed_sequence[0]) # torch takes a 64-bit seed + random.seed((seed_sequence[1] << 32) | seed_sequence[2]) # combine two 64-bit seeds + if _NUMPY_AVAILABLE: + np.random.seed(seed_sequence[3] & 0xFFFFFFFF) # numpy takes 32-bit seed only + + +def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]: + """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG) + algorithm.""" + # Combine base seed, worker id and rank into a unique 64-bit number + combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank + seeds = [] + for _ in range(count): + # x_(n+1) = (a * x_n + c) mod m. With c=1, m=2^64 and a is D. Knuth's constant + combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1) + seeds.append(combined_seed) + return seeds def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" states = { "torch": torch.get_rng_state(), - "numpy": np.random.get_state(), "python": python_get_rng_state(), } + if _NUMPY_AVAILABLE: + states["numpy"] = np.random.get_state() if include_cuda: states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [] return states @@ -121,6 +134,7 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: # torch.cuda rng_state is only included since v1.8. if "torch.cuda" in rng_state_dict: torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"]) - np.random.set_state(rng_state_dict["numpy"]) + if _NUMPY_AVAILABLE and "numpy" in rng_state_dict: + np.random.set_state(rng_state_dict["numpy"]) version, state, gauss = rng_state_dict["python"] python_set_rng_state((version, tuple(state), gauss)) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index d9eea0b147dc2..39eb6e7265cd3 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -17,7 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976)) -- +- Changed the implementation of how seeds are chosen for dataloader workers when using `seed_everything(..., workers=True)` ([#20055](https://github.com/Lightning-AI/pytorch-lightning/pull/20055)) + ### Deprecated diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py index 351f6a47b74cd..bb1d3583f56a6 100644 --- a/tests/tests_fabric/utilities/test_seed.py +++ b/tests/tests_fabric/utilities/test_seed.py @@ -1,11 +1,18 @@ import os +import random from unittest import mock from unittest.mock import Mock import lightning.fabric.utilities +import numpy import pytest import torch -from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states +from lightning.fabric.utilities.seed import ( + _collect_rng_states, + _set_rng_states, + pl_worker_init_function, + seed_everything, +) @mock.patch.dict(os.environ, clear=True) @@ -95,3 +102,26 @@ def test_collect_rng_states_if_cuda_init_fails(get_rng_state_all_mock): get_rng_state_all_mock.side_effect = RuntimeError("The NVIDIA driver on your system is too old") states = _collect_rng_states() assert states["torch.cuda"] == [] + + +@pytest.mark.parametrize(("num_workers", "num_ranks"), [(64, 64)]) +@pytest.mark.parametrize("base_seed", [100, 1024, 2**32 - 1]) +def test_pl_worker_init_function(base_seed, num_workers, num_ranks): + """Test that Lightning's `worker_init_fn` sets unique seeds per worker/rank derived from the base seed.""" + torch_rands = set() + stdlib_rands = set() + numpy_rands = set() + + for worker_id in range(num_workers): + for rank in range(num_ranks): + seed_everything(base_seed) + pl_worker_init_function(worker_id, rank) + torch_rands.add(tuple(torch.randint(0, 1_000_000, (100,)).tolist())) + stdlib_rands.add(tuple(random.randint(0, 1_000_000) for _ in range(100))) + numpy_rands.add(tuple(numpy.random.randint(0, 1_000_000, (100,)).tolist())) + + # Assert there are no duplicates (no collisions) + assert len(torch_rands) == num_ranks * num_workers + assert len(stdlib_rands) == num_ranks * num_workers + assert len(numpy_rands) == num_ranks * num_workers + assert len(torch_rands | stdlib_rands | numpy_rands) == 3 * num_workers * num_ranks