Skip to content

Commit

Permalink
Merge pull request #2335 from devitocodes/tweak-error-handling
Browse files Browse the repository at this point in the history
compiler: Tweak check_stability to ensure cleanup is performed
  • Loading branch information
mloubout committed Mar 27, 2024
2 parents 7b7b1eb + adc0389 commit 497eb50
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 22 deletions.
17 changes: 12 additions & 5 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
__all__ = ['Node', 'MultiTraversable', 'Block', 'Expression', 'Callable',
'Call', 'ExprStmt', 'Conditional', 'Iteration', 'List', 'Section',
'TimedList', 'Prodder', 'MetaCall', 'PointerCast', 'HaloSpot',
'Definition', 'ExpressionBundle', 'AugmentedExpression',
'Definition', 'ExpressionBundle', 'AugmentedExpression', 'Break',
'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration',
'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma',
'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace',
Expand Down Expand Up @@ -432,8 +432,8 @@ def is_initializable(self):
"""
True if it can be an initializing assignment, False otherwise.
"""
return ((self.is_scalar and not self.is_reduction) or
(self.is_tensor and isinstance(self.expr.rhs, ListInitializer)))
return (((self.is_scalar and not self.is_reduction) or
(self.is_tensor and isinstance(self.expr.rhs, ListInitializer))))

@property
def defines(self):
Expand Down Expand Up @@ -796,17 +796,19 @@ class CallableBody(MultiTraversable):
Data deallocations for `body`.
errors : list of Nodes, optional
Error handling for `body`.
retstmt : Node, optional
The return statement for `body`.
"""

is_CallableBody = True

_traversable = ['unpacks', 'init', 'standalones', 'allocs', 'stacks',
'casts', 'bundles', 'maps', 'strides', 'objs', 'body',
'unmaps', 'unbundles', 'frees', 'errors']
'unmaps', 'unbundles', 'frees', 'errors', 'retstmt']

def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(),
allocs=(), stacks=(), casts=(), bundles=(), objs=(), maps=(),
unmaps=(), unbundles=(), frees=(), errors=()):
unmaps=(), unbundles=(), frees=(), errors=(), retstmt=()):
# Sanity check
assert not isinstance(body, CallableBody), "CallableBody's cannot be nested"

Expand All @@ -826,6 +828,7 @@ def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(),
self.unbundles = as_tuple(unbundles)
self.frees = as_tuple(frees)
self.errors = as_tuple(errors)
self.retstmt = as_tuple(retstmt)

def __repr__(self):
return ("<CallableBody <unpacks=%d, allocs=%d, casts=%d, maps=%d, "
Expand Down Expand Up @@ -1385,6 +1388,10 @@ def __repr__(self):
return ""


class Break(Node):
pass


class Return(Node):

def __init__(self, value=None):
Expand Down
11 changes: 10 additions & 1 deletion devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,9 @@ def visit_Section(self, o):
body = flatten(self._visit(i) for i in o.children)
return c.Module(body)

def visit_Break(self, o):
return c.Statement('break')

def visit_Return(self, o):
v = 'return'
if o.value is not None:
Expand Down Expand Up @@ -679,7 +682,13 @@ def visit_Operator(self, o, mode='all'):
# Kernel signature and body
body = flatten(self._visit(i) for i in o.children)
signature = self._gen_signature(o)
retval = [c.Line(), c.Statement("return 0")]

# Honor the `retstmt` flag if set
if o.body.retstmt:
retval = []
else:
retval = [c.Line(), c.Statement("return 0")]

kernel = c.FunctionBody(signature, c.Block(body + retval))

# Elemental functions
Expand Down
34 changes: 32 additions & 2 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
from functools import wraps

import sympy
import numpy as np
from cached_property import cached_property

from devito.finite_differences.differentiable import Mul
from devito.finite_differences.elementary import floor
from devito.finite_differences.elementary import floor, sqrt, besseli, sinc
from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT
from devito.tools import as_tuple, flatten, filter_ordered
from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol,
CustomDimension)
from devito.types.utils import DimensionTuple

__all__ = ['LinearInterpolator', 'PrecomputedInterpolator']
__all__ = ['LinearInterpolator', 'PrecomputedInterpolator', 'SincInterpolator']


def check_radius(func):
Expand Down Expand Up @@ -421,3 +422,32 @@ def _weights(self):
for (ri, rd) in enumerate(self._rdim)]
return Mul(*[self.interpolation_coeffs.subs(mapper)
for mapper in mappers])


class SincInterpolator(LinearInterpolator):
"""
Hicks windowed sinc interpolation scheme.
Arbitrary source and receiver positioning in finite‐difference schemes
using Kaiser windowed sinc functions
https://library.seg.org/doi/10.1190/1.1451454
"""

# Table 1
_b_table = {1: 0.0, 2: 1.84, 3: 3.04,
4: 4.14, 5: 5.26, 6: 6.40,
7: 7.51, 8: 8.56, 9: 9.56, 10: 10.64}

@property
def _weights(self):
b = self._b_table[self.r]
b0 = besseli(0, b).evalf()
W = sympy.S.One
for (rd, pos) in zip(self._rdim, self._point_symbols):
rpos = rd - pos
Wd = besseli(0, b*sqrt(1 - (rpos/self.r)**2))/b0
S = sinc(sympy.pi*rpos)
W *= Wd*S
return W
34 changes: 29 additions & 5 deletions devito/passes/iet/errors.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import cgen as c
import numpy as np
from sympy import Not
from sympy import Expr, Not, S

from devito.ir.iet import (Call, Conditional, EntryFunction, Iteration, List,
Return, FindNodes, FindSymbols, Transformer,
from devito.ir.iet import (Call, Conditional, DummyExpr, EntryFunction, Iteration,
List, Break, Return, FindNodes, FindSymbols, Transformer,
make_callable)
from devito.passes.iet.engine import iet_pass
from devito.symbolics import CondEq, DefFunction
from devito.types import Eq, Inc, Symbol
from devito.tools import dtype_to_ctype
from devito.types import Eq, Inc, LocalObject, Symbol

__all__ = ['check_stability', 'error_mapper']

Expand Down Expand Up @@ -69,18 +70,41 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
name = sregistry.make_name(prefix='check')
check = Symbol(name=name, dtype=np.int32)

retval = Retval(name='retval')

errctl = Conditional(CondEq(n.dim % 100, 0), List(body=[
Call(efunc.name, efunc.parameters, retobj=check),
Conditional(Not(check), Return(error_mapper['Stability']))
Conditional(Not(check), List(body=[
DummyExpr(retval, error_mapper['Stability']),
Break()
]))
]))
errctl = List(header=c.Comment("Stability check"), body=[errctl])
mapper[n] = n._rebuild(nodes=n.nodes + (errctl,))

# One check is enough
break
else:
return iet, {}

iet = Transformer(mapper).visit(iet)

# We now must return a suitable error code
body = iet.body._rebuild(
body=(DummyExpr(retval, 0, init=True),) + iet.body.body,
retstmt=Return(retval)
)
iet = iet._rebuild(body=body)

return iet, {'efuncs': efuncs, 'includes': includes}


class Retval(LocalObject, Expr):

dtype = dtype_to_ctype(np.int32)
default_initvalue = S.Zero


error_mapper = {
'Stability': 100,
'KernelLaunch': 200,
Expand Down
8 changes: 7 additions & 1 deletion devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import sympy

from devito.finite_differences import Max, Min
from devito.finite_differences import Max, Min, sinc
from devito.ir import (Any, Forward, Iteration, List, Prodder, FindApplications,
FindNodes, FindSymbols, Transformer, Uxreplace,
filter_iterations, retrieve_iteration_tree, pull_dims)
Expand Down Expand Up @@ -176,6 +176,12 @@ def _(expr):
return set()


@_generate_macros.register(sinc)
@_generate_macros.register(sympy.sinc)
def _(expr):
return {('sinc(a)', ('(((a) == 0) ? (1) : (sin((a))/(a)))'))}


@iet_pass
def minimize_symbols(iet):
"""
Expand Down
23 changes: 23 additions & 0 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence
from sympy.printing.c import C99CodePrinter

from devito import configuration
from devito.arch.compiler import AOMPCompiler
from devito.arch.archinfo import AppleArm
from devito.symbolics.extended_sympy import MathFunction
from devito.symbolics.inspection import has_integer_args
from devito.types.basic import AbstractFunction

Expand Down Expand Up @@ -241,6 +244,26 @@ def _print_DefFunction(self, expr):
template = ''
return "%s%s(%s)" % (expr.name, template, ','.join(arguments))

def _print_besselK(self, expr, kind):
# Most platform support float version, arm doesn't (only checked OSX)
# TODO: Check graviton
if self.dtype == np.float32 and \
not isinstance(configuration['platform'], AppleArm):
ext = 'f'
else:
ext = ''
# Get order and args
if expr.order == 0:
name = "cyl_bessel_%s0%s" % (kind, ext)
args = expr.args[1]
elif expr.order == 1:
name = "cyl_bessel_%s1%s" % (kind, ext)
args = expr.args[1]
else:
name = "cyl_bessel_%sn%s" % (kind, ext)
args = ", ".join([self._print(i) for i in expr.args])
return self._print(MathFunction(name, args))

_print_MathFunction = _print_DefFunction

def _print_Fallback(self, expr):
Expand Down
15 changes: 11 additions & 4 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from devito.finite_differences import generate_fd_shortcuts
from devito.mpi import MPI, SparseDistributor
from devito.operations import LinearInterpolator, PrecomputedInterpolator
from devito.operations import (LinearInterpolator, PrecomputedInterpolator,
SincInterpolator)
from devito.symbolics import indexify, retrieve_function_carriers
from devito.tools import (ReducerMap, as_tuple, flatten, prod, filter_ordered,
is_integer, dtype_to_mpidtype)
Expand All @@ -29,6 +30,10 @@
'PrecomputedSparseTimeFunction', 'MatrixSparseTimeFunction']


_interpolators = {'linear': LinearInterpolator, 'sinc': SincInterpolator}
_default_radius = {'linear': 1, 'sinc': 2}


class AbstractSparseFunction(DiscreteFunction):

"""
Expand Down Expand Up @@ -799,16 +804,18 @@ class SparseFunction(AbstractSparseFunction):

is_SparseFunction = True

_radius = 1
"""The radius of the stencil operators provided by the SparseFunction."""

_sub_functions = ('coordinates',)

__rkwargs__ = AbstractSparseFunction.__rkwargs__ + ('coordinates',)
__rkwargs__ = AbstractSparseFunction.__rkwargs__ + ('coordinates', 'interpolator')

def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)
self.interpolator = LinearInterpolator(self)

interp = kwargs.get('interpolator', 'linear')
self.interpolator = _interpolators[interp](self)
self._radius = kwargs.get('r', _default_radius[interp])

# Set up sparse point coordinates
coordinates = kwargs.get('coordinates', kwargs.get('coordinates_data'))
Expand Down
4 changes: 4 additions & 0 deletions examples/seismic/acoustic/acoustic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def run(shape=(50, 50, 50), spacing=(20.0, 20.0, 20.0), tn=1000.0,
# Define receiver geometry (spread across x, just below surface)
rec, u, summary = solver.forward(save=save, autotune=autotune)

import matplotlib.pyplot as plt
plt.imshow(rec.data, aspect='auto', cmap='seismic', vmin=-1e1, vmax=1e1)
plt.show()

if preset == 'constant-isotropic':
# With a new m as Constant
v0 = Constant(name="v", value=2.0, dtype=np.float32)
Expand Down
35 changes: 31 additions & 4 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from devito import (Grid, Operator, Dimension, SparseFunction, SparseTimeFunction,
Function, TimeFunction, DefaultDimension, Eq, switchconfig,
PrecomputedSparseFunction, PrecomputedSparseTimeFunction,
MatrixSparseTimeFunction)
MatrixSparseTimeFunction, initialize_function)
from examples.seismic import (demo_model, TimeAxis, RickerSource, Receiver,
AcquisitionGeometry)
from examples.seismic.acoustic import AcousticWaveSolver
Expand All @@ -19,7 +19,7 @@ def unit_box(name='a', shape=(11, 11), grid=None, space_order=1):
grid = grid or Grid(shape=shape)
a = Function(name=name, grid=grid, space_order=space_order)
dims = tuple([np.linspace(0., 1., d) for d in shape])
a.data[:] = np.meshgrid(*dims)[1]
initialize_function(a, np.meshgrid(*dims)[1], 0)
return a


Expand All @@ -33,11 +33,12 @@ def unit_box_time(name='a', shape=(11, 11), space_order=1):
return a


def points(grid, ranges, npoints, name='points'):
def points(grid, ranges, npoints, interpolator='linear', r=1, name='points'):
"""Create a set of sparse points from a set of coordinate
ranges for each spatial dimension.
"""
points = SparseFunction(name=name, grid=grid, npoint=npoints)
points = SparseFunction(name=name, grid=grid, npoint=npoints,
interpolator=interpolator, r=r)
for i, r in enumerate(ranges):
points.coordinates.data[:, i] = np.linspace(r[0], r[1], npoints)
return points
Expand Down Expand Up @@ -436,6 +437,32 @@ def test_inject(shape, coords, result, npoints=19):
assert np.allclose(a.data[indices], result, rtol=1.e-5)


@pytest.mark.parametrize('shape, coords, result', [
((11, 11), [(.05, .95), (.45, .45)], 1.),
((11, 11, 11), [(.05, .95), (.45, .45), (.45, .45)], 0.5)
])
@pytest.mark.parametrize('r', range(2, 11))
def test_inject_sinc(shape, coords, result, r, npoints=19):
"""Test point injection with a set of points forming a line
through the middle of the grid.
"""
a = unit_box(shape=shape, space_order=r)
a.data_with_halo.fill(0)
p = points(a.grid, ranges=coords, npoints=npoints, interpolator='sinc', r=r)

expr = p.inject(a, Float(1.))

op = Operator(expr)

op(a=a)
print(op)

indices = [slice(4, 6, 1) for _ in coords]
indices[0] = slice(1, -1, 1)
print(a.data[indices], result)
assert np.allclose(a.data[indices], result, rtol=1.e-5)


@pytest.mark.parametrize('shape, coords, nexpr, result', [
((11, 11), [(.05, .95), (.45, .45)], 1, 1.),
((11, 11), [(.05, .95), (.45, .45)], 2, 1.),
Expand Down

0 comments on commit 497eb50

Please sign in to comment.