Skip to content

Commit

Permalink
Fix deterministic behavior in ddp_spawn (Lightning-AI#3573)
Browse files Browse the repository at this point in the history
* docs

* set env variable

* fix

* changelog
  • Loading branch information
awaelchli authored Sep 20, 2020
1 parent 9acee67 commit a71d62d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed dataloader shuffling not getting turned off with `overfit_batches > 0` and `distributed_backend = "ddp"` ([#3534](https://github.com/PyTorchLightning/pytorch-lightning/pull/3534))

- Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335))

## [0.9.0] - YYYY-MM-DD

Expand Down
8 changes: 7 additions & 1 deletion pytorch_lightning/accelerators/ddp_base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import os
import re
import torch

Expand All @@ -22,6 +22,7 @@
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.seed import seed_everything

try:
from hydra.utils import to_absolute_path, get_original_cwd
Expand Down Expand Up @@ -97,6 +98,11 @@ def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offs
Returns:
"""
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))


# offset the process id if requested
process_idx = process_idx + proc_offset

Expand Down
18 changes: 13 additions & 5 deletions pytorch_lightning/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,24 @@


def seed_everything(seed: Optional[int] = None) -> int:
"""Function that sets seed for pseudo-random number generators in:
pytorch, numpy, python.random and sets PYTHONHASHSEED environment variable.
"""
Function that sets seed for pseudo-random number generators in:
pytorch, numpy, python.random and sets PYTHONHASHSEED environment variable.
In addition, sets the env variable `PL_GLOBAL_SEED` which will be passed to
spawned subprocesses (e.g. ddp_spawn backend).
Args:
seed: the integer value seed for global random state in Lightning.
If `None`, will read seed from `PL_GLOBAL_SEED` env variable
or select it randomly.
"""
max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min

try:
if seed is None:
seed = _select_seed_randomly(min_seed_value, max_seed_value)
else:
seed = int(seed)
seed = os.environ.get("PL_GLOBAL_SEED", _select_seed_randomly(min_seed_value, max_seed_value))
seed = int(seed)
except (TypeError, ValueError):
seed = _select_seed_randomly(min_seed_value, max_seed_value)

Expand All @@ -47,6 +54,7 @@ def seed_everything(seed: Optional[int] = None) -> int:
seed = _select_seed_randomly(min_seed_value, max_seed_value)

os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
Expand Down

0 comments on commit a71d62d

Please sign in to comment.