Skip to content

Commit

Permalink
mpi: Simplify custom approach using 'factorint' and 'array_split'
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Jun 9, 2023
1 parent 9c7e8d3 commit f26e986
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
21 changes: 10 additions & 11 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ctypes import c_int, c_void_p, sizeof
from itertools import groupby, product
from math import ceil, pow
from sympy import primefactors
from sympy import factorint

import atexit

Expand Down Expand Up @@ -634,16 +634,15 @@ def __new__(cls, items, input_comm):
alloc_procs = np.prod([i for i in items if i != '*'])
rem_procs = int(input_comm.size // alloc_procs)

# Start by using the max prime factor at the first starred position,
# then iteratively decompose as evenly as possible until decomposing
# to the number of `rem_procs`
star_vals = [1] * len(items)
star_i = 0
while rem_procs > 1:
prime_factors = primefactors(rem_procs)
rem_procs //= max(prime_factors)
star_vals[star_i] *= max(prime_factors)
star_i = (star_i + 1) % nstars
# List of all factors of rem_procs in decreasing order
factors = factorint(rem_procs)
vals = [k for (k, v) in factors.items() for _ in range(v)][::-1]

# Split in number of stars
split = np.array_split(vals, nstars)

# Reduce
star_vals = [int(np.prod(s)) for s in split]

# Apply computed star values to the processed
for index, value in zip(star_pos, star_vals):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_custom_topology(self):
(256, ('*', '*', 2), (16, 8, 2)),
(256, ('*', 32, 2), (4, 32, 2)),
])
def test_custom_topology_3d_dummy(self, comm_size, topology, dist_topology):
def test_custom_topology_v2(self, comm_size, topology, dist_topology):
dummy_comm = Bunch(size=comm_size)
custom_topology = CustomTopology(topology, dummy_comm)
assert custom_topology == dist_topology
Expand Down

0 comments on commit f26e986

Please sign in to comment.