Skip to content

Commit

Permalink
mpi: Simplify custom apporach using factorint and array_split
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Jun 8, 2023
1 parent 83e8640 commit 14b1fdc
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ctypes import c_int, c_void_p, sizeof
from itertools import groupby, product
from math import ceil, pow
from sympy import primefactors
from sympy import factorint

import atexit

Expand Down Expand Up @@ -634,20 +634,19 @@ def __new__(cls, items, input_comm):
alloc_procs = np.prod([i for i in items if i != '*'])
rem_procs = int(input_comm.size // alloc_procs)

# Start by using the max prime factor at the first starred position,
# then iteratively decompose as evenly as possible until decomposing
# to the number of `rem_procs`
star_vals = [1] * len(items)
star_i = 0
while rem_procs > 1:
prime_factors = primefactors(rem_procs)
rem_procs //= max(prime_factors)
star_vals[star_i] *= max(prime_factors)
star_i = (star_i + 1) % nstars

# Apply computed star values to the processed
for index, value in zip(star_pos, star_vals):
processed[index] = value
# List of all factors of rem_procs in decreasing order
factors = factorint(rem_procs)
vals = [k for (k, v) in factors.items() for _ in range(v)][::-1]

# Split in number of stars
split = np.array_split(vals, nstars)

# Reduce
star_vals = [int(np.prod(s)) for s in split]

# Apply computed star values to the processed
for index, value in zip(star_pos, star_vals):
processed[index] = value

# Final check that topology matches the communicator size
assert np.prod(processed) == input_comm.size
Expand Down

0 comments on commit 14b1fdc

Please sign in to comment.