Skip to content

Commit

Permalink
mpi: Add numbers other than 1 to custom topology
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Jun 4, 2023
1 parent 7e2d767 commit 74bfec0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 33 deletions.
61 changes: 29 additions & 32 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ctypes import c_int, c_void_p, sizeof
from itertools import groupby, product
from math import ceil, pow
from sympy import root, divisors
from sympy import root, divisors, primefactors

import atexit

Expand Down Expand Up @@ -195,8 +195,8 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None):
# Note: the cloned communicator doesn't need to be explicitly freed;
# mpi4py takes care of that when the object gets out of scope
self._input_comm = (input_comm or MPI.COMM_WORLD).Clone()

if topology is None:
# import pdb;pdb.set_trace()
# `MPI.Compute_dims` sets the dimension sizes to be as close to each other
# as possible, using an appropriate divisibility algorithm. Thus, in 3D:
# * topology[0] >= topology[1] >= topology[2]
Expand All @@ -206,8 +206,7 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None):
# guarantee that 9 ranks are arranged into a 3x3 grid when shape=(9, 9))
self._topology = compute_dims(self._input_comm.size, len(shape))
else:
# A custom topology may contain integers or the wildcard '*', which
# implies `nprocs // nstars`
# A custom topology may contain integers or the wildcard '*'
topology = CustomTopology(topology, self._input_comm)

self._topology = topology
Expand Down Expand Up @@ -610,49 +609,46 @@ class CustomTopology(tuple):
* `('*', 1, '*')` gives: (4, 1, 2)
* `(1, '*', '*')` gives: (1, 4, 2)
Raises
------
If the wildcard `'*'` is used, then the CustomTopology can only contain either
`'*'` or 1's, otherwise a ValueError exception is raised.
Notes
-----
Users shouldn't use this class directly. It's up to the Devito runtime to
instantiate it based on the user input.
"""

def __new__(cls, items, input_comm):
# Keep track of nstars and already defined decompositions
nstars = len([i for i in items if i == '*'])
alloc_procs = np.prod([i for i in items if i != '*'])
remprocs = int(input_comm.size // alloc_procs)

processed = []
# If all inputs are stars, slice as evenly as possible
if nstars == len(items) and root(input_comm.size, nstars).is_Integer:
dd = root(input_comm.size, nstars)
# If all inputs are stars, and nstars root exists slice as evenly as possible
if nstars == len(items) and root(remprocs, nstars).is_Integer:
dd = root(remprocs, nstars)
processed = as_tuple([int(dd) for i in range(nstars)])

# Else if more than one stars are present decompose the domain as evenly
# as possible among the `star`ed dimensions
# Else decompose the domain as evenly as possible among the `star`ed dimensions
elif nstars > 0:
if any(i not in ('*', 1) for i in items):
raise ValueError("Custom topology must be only 1 or *")

# If nstars are not a perfect divisor of the communicator size
# size, prioritize splitting the outermost dimension
remprocs = input_comm.size

if root(input_comm.size, nstars).is_Integer:
v = root(input_comm.size, nstars)
elif remprocs == nstars:
v = remprocs
# If nstars root is an integer decompose remprocs evenly
if root(remprocs, nstars).is_Integer:
dd = root(remprocs, nstars)
# Otherwise prioritize splitting the outermost dimension
else:
divs = divisors(remprocs)
v = divs[(len(divs)) // nstars]
# If we cannot decompose to even number of slices per dimension,
# decompose the outermost with the prime factor
prime_factors = primefactors(remprocs)
if max(prime_factors) > 2:
dd = max(prime_factors)
else:
divs = divisors(remprocs)
dd = divs[(len(divs)) // nstars]

for i in items:
if i == 1:
if isinstance(i, int):
processed.append(i)
elif remprocs % v == 0:
processed.append(v)
remprocs = remprocs // v
elif remprocs % dd == 0:
processed.append(dd)
remprocs = remprocs // dd
else:
processed.append(remprocs)
remprocs = 1
Expand All @@ -663,7 +659,8 @@ def __new__(cls, items, input_comm):
try:
assert np.prod(processed) == input_comm.size
except:
raise ValueError("Invalid `topology` for given nprocs")
raise ValueError("Invalid `topology`", processed, " for given nprocs:",
input_comm.size)

obj = super().__new__(cls, processed)
obj.logical = items
Expand Down
19 changes: 18 additions & 1 deletion tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,31 +188,48 @@ def test_custom_topology(self):
assert f2.size_global == f.size_global

@pytest.mark.parametrize('comm_size, topology, dist_topology', [
(2, (1, '*'), (1, 2)),
(2, ('*', '*'), (2, 1)),
(1, (1, '*', '*'), (1, 1, 1)),
(2, (1, '*', '*'), (1, 2, 1)),
(2, (2, '*', '*'), (2, 1, 1)),
(3, (1, '*', '*'), (1, 3, 1)),
(3, ('*', 1, '*'), (3, 1, 1)),
(3, ('*', '*', 1), (3, 1, 1)),
(4, (2, '*', '*'), (2, 2, 1)),
(4, ('*', 2, '*'), (2, 2, 1)),
(4, ('*', '*', 2), (2, 1, 2)),
(6, ('*', 1, '*'), (3, 1, 2)),
(6, ('*', '*', 1), (3, 2, 1)),
(6, (1, '*', '*'), (1, 3, 2)),
(6, ('*', '*', '*'), (2, 3, 1)), # TOFIX as (3, 2, 1)
(6, ('*', '*', '*'), (3, 2, 1)),
(12, ('*', '*', '*'), (3, 4, 1)),
(12, ('*', 3, '*'), (2, 3, 2)),
(18, ('*', '*', '*'), (3, 3, 2)),
(18, ('*', '*', 9), (2, 1, 9)),
(18, ('*', '*', 3), (3, 2, 3)),
(24, ('*', '*', '*'), (3, 8, 1)),
(8, ('*', 1, '*'), (4, 1, 2)),
(8, ('*', '*', 1), (4, 2, 1)),
(8, (1, '*', '*'), (1, 4, 2)),
(8, ('*', '*', '*'), (2, 2, 2)),
(9, ('*', '*', '*'), (3, 3, 1)),
(11, (1, '*', '*'), (1, 11, 1)),
(22, ('*', '*', '*'), (11, 2, 1)),
(16, ('*', '*', 1), (4, 4, 1)),
(16, ('*', 1, '*'), (4, 1, 4)),
(32, ('*', '*', 1), (8, 4, 1)),
(64, ('*', '*', '*'), (4, 4, 4)),
(64, ('*', '*', 1), (8, 8, 1)),
(64, ('*', 2, 1), (32, 2, 1)),
(64, ('*', 2, 4), (8, 2, 4)),
(128, ('*', '*', 1), (16, 8, 1)),
(256, (1, '*', '*'), (1, 16, 16)),
(256, ('*', 1, '*'), (16, 1, 16)),
(256, ('*', '*', 1), (16, 16, 1)),
(256, ('*', '*', '*'), (8, 8, 4)),
(256, ('*', '*', 2), (16, 8, 2)),
(256, ('*', 32, 2), (4, 32, 2)),
])
def test_custom_topology_3d_dummy(self, comm_size, topology, dist_topology):
dummy_comm = DummyInputComm(comm_size)
Expand Down

0 comments on commit 74bfec0

Please sign in to comment.