Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SubDomainSet bug #1457

Merged
merged 2 commits into from
Sep 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions devito/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ def __str__(self):
@_check_idx
def __getitem__(self, glb_idx, comm_type, gather_rank=None):
loc_idx = self._index_glb_to_loc(glb_idx)
gather = True if isinstance(gather_rank, int) else False
if comm_type is index_by_index or gather:
is_gather = True if isinstance(gather_rank, int) else False
if comm_type is index_by_index or is_gather:
# Retrieve the pertinent local data prior to mpi send/receive operations
data_idx = loc_data_idx(loc_idx)
self._index_stash = flip_idx(glb_idx, self._decomposition)
Expand All @@ -197,7 +197,7 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
self._distributor.all_coords, comm)

it = np.nditer(owners, flags=['refs_ok', 'multi_index'])
if not gather:
if not is_gather:
retval = Data(local_val.shape, local_val.dtype.type,
decomposition=local_val._decomposition,
modulo=(False,)*len(local_val.shape))
Expand All @@ -208,11 +208,11 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
# Iterate over each element of data
while not it.finished:
index = it.multi_index
send_rank = gather_rank if gather else send[index]
send_rank = gather_rank if is_gather else send[index]
if rank == owners[index] and rank == send_rank:
# Current index and destination index are on the same rank
loc_ind = local_si[index]
if gather:
if is_gather:
loc_ind = local_si[index]
retval[global_si[index]] = local_val.data[loc_ind]
else:
Expand All @@ -222,7 +222,7 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
# Current index is on this rank and hence need to send
# the data to the appropriate rank
loc_ind = local_si[index]
send_rank = gather_rank if gather else send[index]
send_rank = gather_rank if is_gather else send[index]
send_ind = global_si[index]
send_val = local_val.data[loc_ind]
reqs = comm.isend([send_ind, send_val], dest=send_rank)
Expand All @@ -231,7 +231,7 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
# Current rank is required to receive data from this index
recval = comm.irecv(source=owners[index])
local_dat = recval.wait()
if gather:
if is_gather:
retval[local_dat[0]] = local_dat[1]
else:
loc_ind = local_si[local_dat[0]]
Expand Down
9 changes: 3 additions & 6 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def _add_implicit(cls, expressions):
but instead are requisites of some specified functionality.
"""
processed = []
seen = set()
for e in expressions:
if e.subdomain:
try:
Expand All @@ -263,11 +262,9 @@ def _add_implicit(cls, expressions):
sub_dims.append(e.subdomain.implicit_dimension)
dims = [d for d in dims if d not in frozenset(sub_dims)]
dims.append(e.subdomain.implicit_dimension)
if e.subdomain not in seen:
grid = list(retrieve_functions(e, mode='unique'))[0].grid
processed.extend([i.func(*i.args, implicit_dims=dims) for i in
e.subdomain._create_implicit_exprs(grid)])
seen.add(e.subdomain)
grid = list(retrieve_functions(e, mode='unique'))[0].grid
processed.extend([i.func(*i.args, implicit_dims=dims) for i in
e.subdomain._create_implicit_exprs(grid)])
dims.extend(e.subdomain.dimensions)
new_e = Eq(e.lhs, e.rhs, subdomain=e.subdomain, implicit_dims=dims)
processed.append(new_e)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,7 +1319,7 @@ def test_sliced_gather_2D(self, start, stop, step):
(1, 8, 3),
((0, 4, 4), None, (2, 1, 1))])
def test_sliced_gather_3D(self, start, stop, step):
""" Test gather for various 2D slices."""
""" Test gather for various 3D slices."""
grid = Grid(shape=(10, 10, 10), extent=(9, 9, 9))
f = Function(name='f', grid=grid, dtype=np.int32)
dat = np.arange(1000).reshape(grid.shape)
Expand Down
62 changes: 62 additions & 0 deletions tests/test_subdomains.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,65 @@ class Inner(SubDomainSet):
fex.data[:] = np.transpose(expected)

assert((np.array(result) == np.array(fex.data[:])).all())

def test_multi_sets_eq(self):
"""
Check functionality for when multiple subdomain sets are present, each
with multiple equations.
"""

Nx = 10
Ny = Nx
n_domains = 2

n = Dimension(name='n')
m = Dimension(name='m')

class MySubdomains1(SubDomainSet):
name = 'mydomains1'
implicit_dimension = n

class MySubdomains2(SubDomainSet):
name = 'mydomains2'
implicit_dimension = m

bounds_xm = np.array([1, Nx/2+1], dtype=np.int32)
bounds_xM = np.array([Nx/2+1, 1], dtype=np.int32)
bounds_ym = int(1)
bounds_yM = int(Ny/2+1)
bounds1 = (bounds_xm, bounds_xM, bounds_ym, bounds_yM)

bounds_xm = np.array([1, Nx/2+1], dtype=np.int32)
bounds_xM = np.array([Nx/2+1, 1], dtype=np.int32)
bounds_ym = int(Ny/2+1)
bounds_yM = int(1)
bounds2 = (bounds_xm, bounds_xM, bounds_ym, bounds_yM)

my_sd1 = MySubdomains1(N=n_domains, bounds=bounds1)
my_sd2 = MySubdomains2(N=n_domains, bounds=bounds2)

grid = Grid(extent=(Nx, Ny), shape=(Nx, Ny), subdomains=(my_sd1, my_sd2))

f = Function(name='f', grid=grid, dtype=np.int32)
g = Function(name='g', grid=grid, dtype=np.int32)

eq1 = Eq(f, f+2, subdomain=grid.subdomains['mydomains1'])
eq2 = Eq(g, g+2, subdomain=grid.subdomains['mydomains2'])
eq3 = Eq(f, f-1, subdomain=grid.subdomains['mydomains1'])
eq4 = Eq(g, g+1, subdomain=grid.subdomains['mydomains2'])

op = Operator([eq1, eq2, eq3, eq4])
op.apply()

expected = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
[0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
[0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
[0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
[0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.int32)

assert((np.array(f.data[:]+g.data[:]) == expected).all())