Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

builtins: Support batched initialize_function #2176

Merged
merged 2 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 95 additions & 63 deletions devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,66 @@ def fset(f, g):
return f


def _initialize_function(function, data, nbl, mapper=None, mode='constant'):
"""
Construct the symbolic objects for `initialize_function`.
"""
nbl, slices = nbl_to_padsize(nbl, function.ndim)
if isinstance(data, dv.Function):
function.data[slices] = data.data[:]
else:
function.data[slices] = data
lhs = []
rhs = []
options = []

if mode == 'reflect' and function.grid.distributor.is_parallel:
# Check that HALO size is appropriate
halo = function.halo
local_size = function.shape

def buff(i, j):
return [(i + k - 2*max(max(nbl))) for k in j]

b = [min(l) for l in (w for w in (buff(i, j) for i, j in zip(local_size, halo)))]
if any(np.array(b) < 0):
raise ValueError("Function `%s` halo is not sufficiently thick." % function)

for d, (nl, nr) in zip(function.space_dimensions, as_tuple(nbl)):
dim_l = dv.SubDimension.left(name='abc_%s_l' % d.name, parent=d, thickness=nl)
dim_r = dv.SubDimension.right(name='abc_%s_r' % d.name, parent=d, thickness=nr)
if mode == 'constant':
subsl = nl
subsr = d.symbolic_max - nr
elif mode == 'reflect':
subsl = 2*nl - 1 - dim_l
subsr = 2*(d.symbolic_max - nr) + 1 - dim_r
else:
raise ValueError("Mode not available")
lhs.append(function.subs({d: dim_l}))
lhs.append(function.subs({d: dim_r}))
rhs.append(function.subs({d: subsl}))
rhs.append(function.subs({d: subsr}))
options.extend([None, None])

if mapper and d in mapper.keys():
exprs = mapper[d]
lhs_extra = exprs['lhs']
rhs_extra = exprs['rhs']
lhs.extend(as_list(lhs_extra))
rhs.extend(as_list(rhs_extra))
options_extra = exprs.get('options', len(as_list(lhs_extra))*[None, ])
if isinstance(options_extra, list):
options.extend(options_extra)
else:
options.extend([options_extra])

if all(options is None for i in options):
options = None

return lhs, rhs, options


def initialize_function(function, data, nbl, mapper=None, mode='constant',
name=None, pad_halo=True, **kwargs):
"""
Expand All @@ -225,9 +285,9 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant',

Parameters
----------
function : Function
function : Function or list of Functions
The initialised object.
data : ndarray or Function
data : ndarray or Function or list of ndarray/Function
The data used for initialisation.
nbl : int or tuple of int or tuple of tuple of int
Number of outer layers (such as absorbing layers for boundary damping).
Expand Down Expand Up @@ -286,73 +346,45 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant',
[2, 3, 3, 3, 3, 2],
[2, 2, 2, 2, 2, 2]], dtype=int32)
"""
name = name or 'pad_%s' % function.name
if isinstance(function, dv.TimeFunction):
if isinstance(function, (list, tuple)):
if not isinstance(data, (list, tuple)):
raise TypeError("Expected a list of `data`")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could support this case for init([f,g,h], data) (ie all function initialized to smae data) with just data = [data]*len(function)

elif len(function) != len(data):
raise ValueError("Expected %d `data` items, got %d" %
(len(function), len(data)))

if mapper is not None:
raise NotImplementedError("Unsupported `mapper` with batching")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not even one mapper per function?


functions = function
datas = data
else:
functions = (function,)
datas = (data,)

if any(isinstance(f, dv.TimeFunction) for f in functions):
raise NotImplementedError("TimeFunctions are not currently supported.")

if nbl == 0:
if isinstance(data, dv.Function):
function.data[:] = data.data[:]
else:
function.data[:] = data[:]
if pad_halo:
pad_outhalo(function)
return

nbl, slices = nbl_to_padsize(nbl, function.ndim)
if isinstance(data, dv.Function):
function.data[slices] = data.data[:]
for f, data in zip(functions, datas):
if isinstance(data, dv.Function):
f.data[:] = data.data[:]
else:
f.data[:] = data[:]
else:
function.data[slices] = data
lhs = []
rhs = []
options = []

if mode == 'reflect' and function.grid.distributor.is_parallel:
# Check that HALO size is appropriate
halo = function.halo
local_size = function.shape

def buff(i, j):
return [(i + k - 2*max(max(nbl))) for k in j]
lhss, rhss, optionss = [], [], []
for f, data in zip(functions, datas):
lhs, rhs, options = _initialize_function(f, data, nbl, mapper, mode)

b = [min(l) for l in (w for w in (buff(i, j) for i, j in zip(local_size, halo)))]
if any(np.array(b) < 0):
raise ValueError("Function `%s` halo is not sufficiently thick." % function)
lhss.extend(lhs)
rhss.extend(rhs)
optionss.extend(options)

for d, (nl, nr) in zip(function.space_dimensions, as_tuple(nbl)):
dim_l = dv.SubDimension.left(name='abc_%s_l' % d.name, parent=d, thickness=nl)
dim_r = dv.SubDimension.right(name='abc_%s_r' % d.name, parent=d, thickness=nr)
if mode == 'constant':
subsl = nl
subsr = d.symbolic_max - nr
elif mode == 'reflect':
subsl = 2*nl - 1 - dim_l
subsr = 2*(d.symbolic_max - nr) + 1 - dim_r
else:
raise ValueError("Mode not available")
lhs.append(function.subs({d: dim_l}))
lhs.append(function.subs({d: dim_r}))
rhs.append(function.subs({d: subsl}))
rhs.append(function.subs({d: subsr}))
options.extend([None, None])

if mapper and d in mapper.keys():
exprs = mapper[d]
lhs_extra = exprs['lhs']
rhs_extra = exprs['rhs']
lhs.extend(as_list(lhs_extra))
rhs.extend(as_list(rhs_extra))
options_extra = exprs.get('options', len(as_list(lhs_extra))*[None, ])
if isinstance(options_extra, list):
options.extend(options_extra)
else:
options.extend([options_extra])

if all(options is None for i in options):
options = None
assert len(lhss) == len(rhss) == len(optionss)

assign(lhs, rhs, options=options, name=name, **kwargs)
name = name or 'initialize_%s' % '_'.join(f.name for f in functions)
assign(lhss, rhss, options=optionss, name=name, **kwargs)

if pad_halo:
pad_outhalo(function)
for f in functions:
pad_outhalo(f)
2 changes: 1 addition & 1 deletion devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _interp_idx(self, variables, implicit_dims=None):
temps.extend(self._coeff_temps(implicit_dims))

# Substitution mapper for variables
idx_subs = {v: v.subs({k: c - v.origin.get(k, 0) + p
idx_subs = {v: v.subs({k: c + p
for ((k, c), p) in zip(mapper.items(), pos)})
for v in variables}
idx_subs.update(dict(zip(self._rdim, mapper.values())))
Expand Down
11 changes: 9 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,8 +1245,15 @@ def indexify(self, indices=None, subs=None):
subs = [{**{d.spacing: 1, -d.spacing: -1}, **subs} for d in self.dimensions]

# Indices after substitutions
indices = [sympy.sympify(a.subs(d, d - o).xreplace(s)) for a, d, o, s in
zip(self.args, self.dimensions, self.origin, subs)]
indices = []
for a, d, o, s in zip(self.args, self.dimensions, self.origin, subs):
if d in a.free_symbols:
# Shift by origin d -> d - o.
indices.append(sympy.sympify(a.subs(d, d - o).xreplace(s)))
else:
# Dimension has been removed, e.g. u[10], plain shift by origin
indices.append(sympy.sympify(a - o).xreplace(s))

indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)})
for i in indices]

Expand Down
17 changes: 17 additions & 0 deletions tests/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,23 @@ def test_if_halo_mpi(self, nbl):
expected = np.pad(a[na//2:, na//2:], [(0, 1+nbl), (0, 1+nbl)], 'edge')
assert np.all(f._data_with_outhalo._local == expected)

def test_batching(self):
grid = Grid(shape=(12, 12))

a = np.arange(16).reshape((4, 4))

f = Function(name='f', grid=grid, dtype=np.int32)
g = Function(name='g', grid=grid, dtype=np.float32)
h = Function(name='h', grid=grid, dtype=np.float64)

initialize_function([f, g, h], [a, a, a], 4, mode='reflect')

for i in [f, g, h]:
assert np.all(a[:, ::-1] - np.array(i.data[4:8, 0:4]) == 0)
assert np.all(a[:, ::-1] - np.array(i.data[4:8, 8:12]) == 0)
assert np.all(a[::-1, :] - np.array(i.data[0:4, 4:8]) == 0)
assert np.all(a[::-1, :] - np.array(i.data[8:12, 4:8]) == 0)


class TestBuiltinsResult(object):

Expand Down
11 changes: 11 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ def test_indexed():
assert ub.indexed.free_symbols == {ub.indexed}


def test_indexed_staggered():
grid = Grid(shape=(10, 10))
x, y = grid.dimensions
hx, hy = x.spacing, y.spacing

u = Function(name='u', grid=grid, staggered=(x, y))
u0 = u.subs({x: 1, y: 2})
assert u0.indices == (1 + hx / 2, 2 + hy / 2)
assert u0.indexify().indices == (1, 2)


def test_bundle():
grid = Grid(shape=(4, 4))

Expand Down
Loading