Fix comm initialization in MPIEnvironment (#19074)
(cherry picked from commit 197b225)
awaelchli authored and Borda committed Dec 19, 2023
1 parent 8b9b995 commit e1356ed
Showing 5 changed files with 12 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-tests-fabric.yml
@@ -164,6 +164,7 @@ jobs:
working-directory: tests/tests_fabric
# NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
run: |
echo $GITHUB_RUN_ID
python -m coverage run --source ${{ env.COVERAGE_SCOPE }} \
-m pytest -v --timeout=30 --durations=50 --random-order-seed=$GITHUB_RUN_ID
4 changes: 4 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -17,6 +17,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-


- Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074))



## [2.1.2] - 2023-11-15

### Fixed
4 changes: 2 additions & 2 deletions src/lightning/fabric/plugins/environments/mpi.py
@@ -107,9 +107,9 @@ def _get_main_port(self) -> int:

     def _init_comm_local(self) -> None:
         hostname = socket.gethostname()
-        all_hostnames = self._comm_world.gather(hostname, root=0)
+        all_hostnames = self._comm_world.gather(hostname, root=0)  # returns None on non-root ranks
         # sort all the hostnames, and find unique ones
-        unique_hosts = sorted(set(all_hostnames))
+        unique_hosts = sorted(set(all_hostnames)) if all_hostnames is not None else []
         unique_hosts = self._comm_world.bcast(unique_hosts, root=0)
         # find the index for this host in the list of hosts:
         self._node_rank = unique_hosts.index(hostname)
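
The change in this hunk guards against `gather` returning `None` on every rank except the root, so that only the root computes the sorted host list and the other ranks receive it through `bcast`. The following is a minimal sketch of that pattern without a real MPI runtime. `FakeComm` and `compute_node_rank` are hypothetical names invented for illustration; they are not part of mpi4py or Lightning, and `FakeComm` only imitates the root/non-root return values for one simulated rank.

```python
from typing import Any, List, Optional


class FakeComm:
    """Hypothetical stand-in for an MPI communicator (illustration only).

    Simulates what a single rank would observe in a job whose per-rank
    hostnames are known up front.
    """

    def __init__(self, rank: int, hostnames: List[str]) -> None:
        self.rank = rank
        self._hostnames = hostnames

    def gather(self, obj: Any, root: int = 0) -> Optional[List[Any]]:
        # Only the root rank receives the gathered list; other ranks get None.
        return list(self._hostnames) if self.rank == root else None

    def bcast(self, obj: Any, root: int = 0) -> Any:
        # The root's value wins; non-root ranks receive what the root would send.
        if self.rank == root:
            return obj
        return sorted(set(self._hostnames))


def compute_node_rank(comm: FakeComm, hostname: str) -> int:
    all_hostnames = comm.gather(hostname, root=0)
    # Guard: on non-root ranks there is nothing to sort yet.
    unique_hosts = sorted(set(all_hostnames)) if all_hostnames is not None else []
    unique_hosts = comm.bcast(unique_hosts, root=0)
    return unique_hosts.index(hostname)


hosts = ["host2", "host1", "host1", "host2"]  # hostname reported by ranks 0..3
root_rank = compute_node_rank(FakeComm(0, hosts), "host2")
other_rank = compute_node_rank(FakeComm(3, hosts), "host2")
```

Under these assumptions both the root and a non-root rank on `host2` resolve to the same node index, which is the behavior the patched `_init_comm_local` is meant to provide.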
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an edge case where `ModelCheckpoint` would alternate between versioned and unversioned filename ([#19064](https://github.com/Lightning-AI/lightning/pull/19064))


- Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074))


## [2.1.2] - 2023-11-15

### Fixed
2 changes: 2 additions & 0 deletions tests/tests_fabric/plugins/environments/test_mpi.py
@@ -80,11 +80,13 @@ def test_init_local_comm(monkeypatch):
     env = MPIEnvironment()

     hostname_mock.return_value = "host1"
+    env._comm_world.gather.return_value = ["host1", "host2"]
     env._comm_world.bcast.return_value = ["host1", "host2"]
     assert env.node_rank() == 0

     env._node_rank = None
     hostname_mock.return_value = "host2"
+    env._comm_world.gather.return_value = None
     env._comm_world.bcast.return_value = ["host1", "host2"]
     assert env.node_rank() == 1
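
The two added test lines make the mocked `gather` behave like a root rank (returns a list) and like a non-root rank (returns `None`). As a rough sketch of why the second case matters, the snippet below runs the guarded logic against a `unittest.mock.MagicMock` communicator. `node_rank_from_comm` is a hypothetical helper that mirrors the patched lines; it is not a function in Lightning.

```python
from unittest.mock import MagicMock


def node_rank_from_comm(comm, hostname: str) -> int:
    # Hypothetical helper mirroring the patched logic in _init_comm_local.
    all_hostnames = comm.gather(hostname, root=0)
    unique_hosts = sorted(set(all_hostnames)) if all_hostnames is not None else []
    unique_hosts = comm.bcast(unique_hosts, root=0)
    return unique_hosts.index(hostname)


# Simulate a non-root rank: gather yields None, bcast delivers the root's list.
comm = MagicMock()
comm.gather.return_value = None
comm.bcast.return_value = ["host1", "host2"]

rank = node_rank_from_comm(comm, "host2")

# The non-root rank passes an empty placeholder into bcast.
comm.bcast.assert_called_once_with([], root=0)

# Without the guard, the unpatched expression fails on None.
try:
    sorted(set(None))
    unguarded_failed = False
except TypeError:
    unguarded_failed = True
```

Setting `gather.return_value = None` explicitly is what forces the test through the non-root branch, rather than relying on whatever a bare mock happens to return.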

