From e1356ed84f359ec5e5cd0ee38077b02386b70927 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Nov 2023 22:14:46 +0100 Subject: [PATCH] Fix comm initialization in `MPIEnvironment` (#19074) (cherry picked from commit 197b22586a449fd24d6835c2b186d7900de05c93) --- .github/workflows/ci-tests-fabric.yml | 1 + src/lightning/fabric/CHANGELOG.md | 4 ++++ src/lightning/fabric/plugins/environments/mpi.py | 4 ++-- src/lightning/pytorch/CHANGELOG.md | 3 +++ tests/tests_fabric/plugins/environments/test_mpi.py | 2 ++ 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index 635c77f479fc4..0a142bdea83db 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -164,6 +164,7 @@ jobs: working-directory: tests/tests_fabric # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 run: | + echo $GITHUB_RUN_ID python -m coverage run --source ${{ env.COVERAGE_SCOPE }} \ -m pytest -v --timeout=30 --durations=50 --random-order-seed=$GITHUB_RUN_ID diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index c16b887ee6ec7..dac37761f3518 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -17,6 +17,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +- Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074)) + + + ## [2.1.2] - 2023-11-15 ### Fixed diff --git a/src/lightning/fabric/plugins/environments/mpi.py b/src/lightning/fabric/plugins/environments/mpi.py index e40fe8b027790..bcc4122324a19 100644 --- a/src/lightning/fabric/plugins/environments/mpi.py +++ b/src/lightning/fabric/plugins/environments/mpi.py @@ -107,9 +107,9 @@ def _get_main_port(self) -> int: def _init_comm_local(self) -> None: hostname = socket.gethostname() - all_hostnames = self._comm_world.gather(hostname, root=0) + all_hostnames = self._comm_world.gather(hostname, root=0) # returns None on non-root ranks # sort all the hostnames, and find unique ones - unique_hosts = sorted(set(all_hostnames)) + unique_hosts = sorted(set(all_hostnames)) if all_hostnames is not None else [] unique_hosts = self._comm_world.bcast(unique_hosts, root=0) # find the index for this host in the list of hosts: self._node_rank = unique_hosts.index(hostname) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 064bab732da92..70de07d36e4cf 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an edge case where `ModelCheckpoint` would alternate between versioned and unversioned filename ([#19064](https://github.com/Lightning-AI/lightning/pull/19064)) +- Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074)) + + ## [2.1.2] - 2023-11-15 ### Fixed diff --git a/tests/tests_fabric/plugins/environments/test_mpi.py b/tests/tests_fabric/plugins/environments/test_mpi.py index 2e283ea801ca9..649d4dcb1dab2 100644 --- a/tests/tests_fabric/plugins/environments/test_mpi.py +++ b/tests/tests_fabric/plugins/environments/test_mpi.py @@ -80,11 +80,13 @@ def test_init_local_comm(monkeypatch): env = MPIEnvironment() hostname_mock.return_value = "host1" + env._comm_world.gather.return_value = ["host1", "host2"] env._comm_world.bcast.return_value = ["host1", "host2"] assert env.node_rank() == 0 env._node_rank = None hostname_mock.return_value = "host2" + env._comm_world.gather.return_value = None env._comm_world.bcast.return_value = ["host1", "host2"] assert env.node_rank() == 1