From 14b1fdcbd51f455ad001c2ef180b6dbda016c10c Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Thu, 8 Jun 2023 17:13:50 +0100 Subject: [PATCH] mpi: Simplify custom approach using factorint and array_split --- devito/mpi/distributed.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py index c0fee0f5a3d..947a764f9f5 100644 --- a/devito/mpi/distributed.py +++ b/devito/mpi/distributed.py @@ -2,7 +2,7 @@ from ctypes import c_int, c_void_p, sizeof from itertools import groupby, product from math import ceil, pow -from sympy import primefactors +from sympy import factorint import atexit @@ -634,20 +634,19 @@ def __new__(cls, items, input_comm): alloc_procs = np.prod([i for i in items if i != '*']) rem_procs = int(input_comm.size // alloc_procs) - # Start by using the max prime factor at the first starred position, - # then iteratively decompose as evenly as possible until decomposing - # to the number of `rem_procs` - star_vals = [1] * len(items) - star_i = 0 - while rem_procs > 1: - prime_factors = primefactors(rem_procs) - rem_procs //= max(prime_factors) - star_vals[star_i] *= max(prime_factors) - star_i = (star_i + 1) % nstars - - # Apply computed star values to the processed - for index, value in zip(star_pos, star_vals): - processed[index] = value + # List of all factors of rem_procs in decreasing order + factors = factorint(rem_procs) + vals = [k for (k, v) in factors.items() for _ in range(v)][::-1] + + # Split in number of stars + split = np.array_split(vals, nstars) + + # Reduce + star_vals = [int(np.prod(s)) for s in split] + + # Apply computed star values to the processed + for index, value in zip(star_pos, star_vals): + processed[index] = value # Final check that topology matches the communicator size assert np.prod(processed) == input_comm.size