Make numpy an optional dependency in utilities/seed.py #20055

Merged · 10 commits · Jul 12, 2024
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed
 
--
+- Changed the implementation of how seeds are chosen for dataloader workers when using `seed_everything(..., workers=True)` ([#20055](https://github.com/Lightning-AI/pytorch-lightning/pull/20055))
 
 -

5 changes: 4 additions & 1 deletion src/lightning/fabric/utilities/imports.py
@@ -17,7 +17,10 @@
 import platform
 import sys
 
-from lightning_utilities.core.imports import compare_version
+from lightning_utilities.core.imports import RequirementCache, compare_version
+
+_NUMPY_AVAILABLE = RequirementCache("numpy")
+
 
 _IS_WINDOWS = platform.system() == "Windows"

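For reviewers unfamiliar with `RequirementCache`: truth-testing it checks (once, then caches) whether the requirement string is satisfied in the current environment, so the module-level flag above is cheap to consult at call sites. A minimal sketch of the intended gating pattern — the `seed_numpy_if_available` helper is hypothetical, for illustration only:

```python
from lightning_utilities.core.imports import RequirementCache

_NUMPY_AVAILABLE = RequirementCache("numpy")


def seed_numpy_if_available(seed: int) -> None:
    # The flag is truthy only when `numpy` is installed, so the
    # import inside this guard cannot fail.
    if _NUMPY_AVAILABLE:
        import numpy as np

        np.random.seed(seed)
```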
44 changes: 29 additions & 15 deletions src/lightning/fabric/utilities/seed.py
@@ -3,17 +3,19 @@
 import random
 from random import getstate as python_get_rng_state
 from random import setstate as python_set_rng_state
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 import torch
 
+from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE
 from lightning.fabric.utilities.rank_zero import _get_rank, rank_prefixed_message, rank_zero_only, rank_zero_warn
 
 log = logging.getLogger(__name__)
 
-max_seed_value = np.iinfo(np.uint32).max
-min_seed_value = np.iinfo(np.uint32).min
+
+max_seed_value = 4294967295  # 2^32 - 1 (uint32)
+min_seed_value = 0


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
@@ -54,7 +56,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
     os.environ["PL_GLOBAL_SEED"] = str(seed)
     random.seed(seed)
-    np.random.seed(seed)
+    if _NUMPY_AVAILABLE:
+        np.random.seed(seed)
     torch.manual_seed(seed)
 
     os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
@@ -91,24 +94,34 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
     log.debug(
         f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
     )
-    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
-    # use 128 bits (4 x 32-bit words)
-    np.random.seed(ss.generate_state(4))
-    # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
-    torch_ss, stdlib_ss = ss.spawn(2)
-    torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
-    # use 128 bits expressed as an integer
-    stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
-    random.seed(stdlib_seed)
+    seed_sequence = _generate_seed_sequence(base_seed, worker_id, global_rank, count=4)
+    torch.manual_seed(seed_sequence[0])  # torch takes a 64-bit seed
+    random.seed((seed_sequence[1] << 32) | seed_sequence[2])  # combine two 64-bit seeds
+    if _NUMPY_AVAILABLE:
+        np.random.seed(seed_sequence[3] & 0xFFFFFFFF)  # numpy takes 32-bit seed only
+
+
+def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]:
+    """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator
+    (LCG) algorithm."""
+    # Combine base seed, worker id and rank into a unique 64-bit number
+    combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank
+    seeds = []
+    for _ in range(count):
+        # x_(n+1) = (a * x_n + c) mod m. With c=1, m=2^64 and a is D. Knuth's constant
+        combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1)
+        seeds.append(combined_seed)
+    return seeds


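A quick standalone check (illustrative, not part of the PR) of why the packing step above gives collision-free starting states: `base_seed`, `worker_id`, and `global_rank` occupy disjoint bit ranges, at least while the worker id and rank each fit in 16 bits, so distinct triples always produce distinct LCG states and therefore distinct derived seeds.

```python
def pack(base_seed: int, worker_id: int, global_rank: int) -> int:
    # Same packing as _generate_seed_sequence: seed in the high 32 bits,
    # worker id in bits 16-31, rank in bits 0-15.
    return (base_seed << 32) | (worker_id << 16) | global_rank


# Every (worker, rank) pair maps to a unique 64-bit LCG starting state.
states = {pack(42, w, r) for w in range(64) for r in range(64)}
assert len(states) == 64 * 64
```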
 def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
     r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
     states = {
         "torch": torch.get_rng_state(),
-        "numpy": np.random.get_state(),
         "python": python_get_rng_state(),
     }
+    if _NUMPY_AVAILABLE:
+        states["numpy"] = np.random.get_state()
     if include_cuda:
         states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
     return states
@@ -121,6 +134,7 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
     # torch.cuda rng_state is only included since v1.8.
     if "torch.cuda" in rng_state_dict:
         torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
-    np.random.set_state(rng_state_dict["numpy"])
+    if _NUMPY_AVAILABLE and "numpy" in rng_state_dict:
+        np.random.set_state(rng_state_dict["numpy"])
     version, state, gauss = rng_state_dict["python"]
     python_set_rng_state((version, tuple(state), gauss))
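Taken together, the two helpers form a save/restore pair. A small round-trip sketch — these are private utilities, so treat this as illustrative; it assumes `_set_rng_states` also restores the `"torch"` and `"python"` states collected above, per the elided context:

```python
import torch

from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states

snapshot = _collect_rng_states(include_cuda=False)  # "numpy" key present only when numpy is installed
first = torch.rand(3)
_set_rng_states(snapshot)  # rewind the global generators
again = torch.rand(3)
assert torch.equal(first, again)  # identical draws after the restore
```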
3 changes: 2 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -17,7 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))
 
--
+- Changed the implementation of how seeds are chosen for dataloader workers when using `seed_everything(..., workers=True)` ([#20055](https://github.com/Lightning-AI/pytorch-lightning/pull/20055))
+

### Deprecated

32 changes: 31 additions & 1 deletion tests/tests_fabric/utilities/test_seed.py
@@ -1,11 +1,18 @@
 import os
+import random
 from unittest import mock
 from unittest.mock import Mock
 
 import lightning.fabric.utilities
+import numpy
 import pytest
 import torch
-from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
+from lightning.fabric.utilities.seed import (
+    _collect_rng_states,
+    _set_rng_states,
+    pl_worker_init_function,
+    seed_everything,
+)


@mock.patch.dict(os.environ, clear=True)
@@ -95,3 +102,26 @@ def test_collect_rng_states_if_cuda_init_fails(get_rng_state_all_mock):
     get_rng_state_all_mock.side_effect = RuntimeError("The NVIDIA driver on your system is too old")
     states = _collect_rng_states()
     assert states["torch.cuda"] == []
+
+
+@pytest.mark.parametrize(("num_workers", "num_ranks"), [(64, 64)])
+@pytest.mark.parametrize("base_seed", [100, 1024, 2**32 - 1])
+def test_pl_worker_init_function(base_seed, num_workers, num_ranks):
+    """Test that Lightning's `worker_init_fn` sets unique seeds per worker/rank derived from the base seed."""
+    torch_rands = set()
+    stdlib_rands = set()
+    numpy_rands = set()
+
+    for worker_id in range(num_workers):
+        for rank in range(num_ranks):
+            seed_everything(base_seed)
+            pl_worker_init_function(worker_id, rank)
+            torch_rands.add(tuple(torch.randint(0, 1_000_000, (100,)).tolist()))
+            stdlib_rands.add(tuple(random.randint(0, 1_000_000) for _ in range(100)))
+            numpy_rands.add(tuple(numpy.random.randint(0, 1_000_000, (100,)).tolist()))
+
+    # Assert there are no duplicates (no collisions)
+    assert len(torch_rands) == num_ranks * num_workers
+    assert len(stdlib_rands) == num_ranks * num_workers
+    assert len(numpy_rands) == num_ranks * num_workers
+    assert len(torch_rands | stdlib_rands | numpy_rands) == 3 * num_workers * num_ranks
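For context on how `pl_worker_init_function` is normally wired up outside this test: Lightning attaches it as the `worker_init_fn` of the `DataLoader` when `seed_everything(..., workers=True)` was used. A rough hand-rolled equivalent — the `RandomSquares` dataset is made up for illustration, and `rank` would come from the distributed environment rather than being hardcoded:

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.fabric.utilities.seed import pl_worker_init_function


class RandomSquares(Dataset):
    def __len__(self) -> int:
        return 8

    def __getitem__(self, i: int) -> torch.Tensor:
        # Depends on the worker's RNG state, hence on the derived seed.
        return torch.randint(0, 1000, (1,)) ** 2


loader = DataLoader(
    RandomSquares(),
    num_workers=2,
    # DataLoader calls worker_init_fn(worker_id) in each worker process.
    worker_init_fn=partial(pl_worker_init_function, rank=0),
)
```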