Skip to content

Commit

Permalink
mpi: Simplify algorithm logic for decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed May 22, 2023
1 parent bae6479 commit 0a8ecf0
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 31 deletions.
64 changes: 35 additions & 29 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None):
# Note: the cloned communicator doesn't need to be explicitly freed;
# mpi4py takes care of that when the object gets out of scope
self._input_comm = (input_comm or MPI.COMM_WORLD).Clone()

if topology is None:
# import pdb;pdb.set_trace()
# `MPI.Compute_dims` sets the dimension sizes to be as close to each other
# as possible, using an appropriate divisibility algorithm. Thus, in 3D:
# * topology[0] >= topology[1] >= topology[2]
Expand Down Expand Up @@ -621,39 +621,45 @@ def __new__(cls, items, input_comm):
alloc_procs = np.prod([i for i in items if i != '*'])
remprocs = int(input_comm.size // alloc_procs)

processed = []
# If no stars exist we are ready
if nstars == 0:
processed = items
# If all inputs are stars and the nstars-th root of remprocs is an integer, slice as evenly as possible
if nstars == len(items) and root(remprocs, nstars).is_Integer:
elif nstars == len(items) and root(remprocs, nstars).is_Integer:
dd = root(remprocs, nstars)
processed = as_tuple([int(dd) for i in range(nstars)])
# Process nstars > 0
else:
processed = [1] * len(items)

# Else decompose the domain as evenly as possible among the `star`ed dimensions
elif nstars > 0:
# If the nstars-th root of remprocs is an integer, decompose remprocs evenly
if root(remprocs, nstars).is_Integer:
dd = root(remprocs, nstars)
# Otherwise prioritize splitting the outermost dimension
else:
# If we cannot decompose to even number of slices per dimension,
# decompose the outermost with the prime factor
# Get star and ints positions
int_pos = [i for i, item in enumerate(items) if isinstance(item, int)]
int_vals = [item for item in items if isinstance(item, int)]
star_pos = [i for i, item in enumerate(items) if not isinstance(item, int)]

# Decompose the processes remaining for allocation to prime factors
prime_factors = primefactors(remprocs)

star_i = -1
dd_list = [1] * nstars

# Start by assigning the largest prime factor of `remprocs` to the first starred
# position, then cycle through the starred positions, multiplying each in turn by
# the largest prime factor of what is left, until `remprocs` is reduced to 1
while remprocs != 1:
star_i = star_i + 1
star_i = star_i % nstars
prime_factors = primefactors(remprocs)
if max(prime_factors) > 2:
dd = max(prime_factors)
else:
divs = divisors(remprocs)
dd = divs[(len(divs)) // nstars]

for i in items:
if isinstance(i, int):
processed.append(i)
elif remprocs % dd == 0:
processed.append(dd)
remprocs = remprocs // dd
else:
processed.append(remprocs)
remprocs = 1
else:
processed = items
dd_list[star_i] = dd_list[star_i]*max(prime_factors)
remprocs = remprocs // max(prime_factors)

if int_pos:
for index, value in zip(int_pos, int_vals):
processed[index] = value

if dd_list:
for index, value in zip(star_pos, dd_list):
processed[index] = value

# Final check that topology matches the communicator size
try:
Expand Down
6 changes: 4 additions & 2 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,13 @@ def test_custom_topology(self):
(6, ('*', '*', 1), (3, 2, 1)),
(6, (1, '*', '*'), (1, 3, 2)),
(6, ('*', '*', '*'), (3, 2, 1)),
(12, ('*', '*', '*'), (3, 4, 1)),
(12, ('*', '*', '*'), (3, 2, 2)),
(12, ('*', 3, '*'), (2, 3, 2)),
(18, ('*', '*', '*'), (3, 3, 2)),
(18, ('*', '*', 9), (2, 1, 9)),
(18, ('*', '*', 3), (3, 2, 3)),
(24, ('*', '*', '*'), (3, 8, 1)),
(24, ('*', '*', '*'), (6, 2, 2)),
(32, ('*', '*', '*'), (4, 4, 2)),
(8, ('*', 1, '*'), (4, 1, 2)),
(8, ('*', '*', 1), (4, 2, 1)),
(8, (1, '*', '*'), (1, 4, 2)),
Expand All @@ -224,6 +225,7 @@ def test_custom_topology(self):
(64, ('*', 2, 1), (32, 2, 1)),
(64, ('*', 2, 4), (8, 2, 4)),
(128, ('*', '*', 1), (16, 8, 1)),
(231, ('*', '*', '*'), (11, 7, 3)),
(256, (1, '*', '*'), (1, 16, 16)),
(256, ('*', 1, '*'), (16, 1, 16)),
(256, ('*', '*', 1), (16, 16, 1)),
Expand Down

0 comments on commit 0a8ecf0

Please sign in to comment.