diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py index 8c05b9d7689..c3211a3d83b 100644 --- a/devito/mpi/distributed.py +++ b/devito/mpi/distributed.py @@ -2,7 +2,7 @@ from ctypes import c_int, c_void_p, sizeof from itertools import groupby, product from math import ceil, pow -from sympy import root, divisors, primefactors +from sympy import root, primefactors import atexit @@ -195,8 +195,8 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None): # Note: the cloned communicator doesn't need to be explicitly freed; # mpi4py takes care of that when the object gets out of scope self._input_comm = (input_comm or MPI.COMM_WORLD).Clone() + if topology is None: - # import pdb;pdb.set_trace() # `MPI.Compute_dims` sets the dimension sizes to be as close to each other # as possible, using an appropriate divisibility algorithm. Thus, in 3D: # * topology[0] >= topology[1] >= topology[2] @@ -621,39 +621,45 @@ def __new__(cls, items, input_comm): alloc_procs = np.prod([i for i in items if i != '*']) remprocs = int(input_comm.size // alloc_procs) - processed = [] + # If no stars exist we are ready + if nstars == 0: + processed = items # If all inputs are stars, and nstars root exists slice as evenly as possible - if nstars == len(items) and root(remprocs, nstars).is_Integer: + elif nstars == len(items) and root(remprocs, nstars).is_Integer: dd = root(remprocs, nstars) processed = as_tuple([int(dd) for i in range(nstars)]) + # Process nstars > 0 + else: + processed = [1] * len(items) - # Else decompose the domain as evenly as possible among the `star`ed dimensions - elif nstars > 0: - # If nstars root is an integer decompose remprocs evenly - if root(remprocs, nstars).is_Integer: - dd = root(remprocs, nstars) - # Otherwise prioritize splitting the outermost dimension - else: - # If we cannot decompose to even number of slices per dimension, - # decompose the outermost with the prime factor + # Get star and ints positions + int_pos = [i for i, item in enumerate(items) if isinstance(item, int)] + int_vals = [item for item in items if isinstance(item, int)] + star_pos = [i for i, item in enumerate(items) if not isinstance(item, int)] + + # Decompose the processes remaining for allocation to prime factors + prime_factors = primefactors(remprocs) + + star_i = -1 + dd_list = [1] * nstars + + # Start by using the max prime factor at the first starred position, + # then cyclically-iteratively decompose as evenly as possible until + # decomposing to the number of `remprocs` + while remprocs != 1: + star_i = star_i + 1 + star_i = star_i % nstars prime_factors = primefactors(remprocs) - if max(prime_factors) > 2: - dd = max(prime_factors) - else: - divs = divisors(remprocs) - dd = divs[(len(divs)) // nstars] - - for i in items: - if isinstance(i, int): - processed.append(i) - elif remprocs % dd == 0: - processed.append(dd) - remprocs = remprocs // dd - else: - processed.append(remprocs) - remprocs = 1 - else: - processed = items + dd_list[star_i] = dd_list[star_i]*max(prime_factors) + remprocs = remprocs // max(prime_factors) + + if int_pos: + for index, value in zip(int_pos, int_vals): + processed[index] = value + + if dd_list: + for index, value in zip(star_pos, dd_list): + processed[index] = value # Final check that topology matches the communicator size try: diff --git a/tests/test_mpi.py b/tests/test_mpi.py index afb1b53cb30..3ec4ac9af61 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -203,12 +203,13 @@ def test_custom_topology(self): (6, ('*', '*', 1), (3, 2, 1)), (6, (1, '*', '*'), (1, 3, 2)), (6, ('*', '*', '*'), (3, 2, 1)), - (12, ('*', '*', '*'), (3, 4, 1)), + (12, ('*', '*', '*'), (3, 2, 2)), (12, ('*', 3, '*'), (2, 3, 2)), (18, ('*', '*', '*'), (3, 3, 2)), (18, ('*', '*', 9), (2, 1, 9)), (18, ('*', '*', 3), (3, 2, 3)), - (24, ('*', '*', '*'), (3, 8, 1)), + (24, ('*', '*', '*'), (6, 2, 2)), + (32, ('*', '*', '*'), (4, 4, 2)), (8, ('*', 1, '*'), (4, 1, 2)), (8, ('*', '*', 1), (4, 2, 1)), (8, (1, '*', '*'), (1, 4, 2)), @@ -224,6 +225,7 @@ def test_custom_topology(self): (64, ('*', 2, 1), (32, 2, 1)), (64, ('*', 2, 4), (8, 2, 4)), (128, ('*', '*', 1), (16, 8, 1)), + (231, ('*', '*', '*'), (11, 7, 3)), (256, (1, '*', '*'), (1, 16, 16)), (256, ('*', 1, '*'), (16, 1, 16)), (256, ('*', '*', 1), (16, 16, 1)),