Skip to content

Commit

Permalink
Merge pull request #283 from opesci/fix-iteration-direction
Browse files Browse the repository at this point in the history
Fix iteration direction
  • Loading branch information
navjotk authored Jul 5, 2017
2 parents ea614e0 + c79fd2e commit d044efc
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
5 changes: 4 additions & 1 deletion devito/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ def __init__(self, nodes, dimension, limits, index=None, offsets=None,

self.dim = dimension
self.index = index or self.dim.name
# Store direction, as it might change on the dimension
# before we use it during code generation.
self.reverse = self.dim.reverse

# Generate loop limits
if isinstance(limits, Iterable):
Expand Down Expand Up @@ -346,7 +349,7 @@ def ccode(self):
end = self.limits[1]

# For reverse dimensions flip loop bounds
if self.dim.reverse:
if self.reverse:
loop_init = c.InlineInitializer(c.Value("int", self.index),
ccode('%s - 1' % end))
loop_cond = '%s >= %s' % (self.index, ccode(start))
Expand Down
29 changes: 22 additions & 7 deletions tests/test_timestepping.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,14 @@ def d(shape=(11, 11)):

def test_forward(a):
a.data[0, :] = 1.
eqn = Eq(a.forward, a + 1.)
Operator(eqn, dle=None, dse=None)()
Operator(Eq(a.forward, a + 1.))()
for i in range(a.shape[0]):
assert np.allclose(a.data[i, :], 1. + i, rtol=1.e-12)


def test_backward(b):
b.data[-1, :] = 7.
eqn = Eq(b.backward, b - 1.)
Operator(eqn, dle=None, dse=None, time_axis=Backward)()
Operator(Eq(b.backward, b - 1.), time_axis=Backward)()
for i in range(b.shape[0]):
assert np.allclose(b.data[i, :], 2. + i, rtol=1.e-12)

Expand All @@ -57,19 +55,36 @@ def test_forward_unroll(a, c, nt=5):
c.data[0, :] = 1.
eqn_c = Eq(c.forward, c + 1.)
eqn_a = Eq(a.forward, c.forward)
Operator([eqn_c, eqn_a], dle=None, dse=None)(time=nt)
Operator([eqn_c, eqn_a])(time=nt)
for i in range(nt):
assert np.allclose(a.data[i, :], 1. + i, rtol=1.e-12)


def test_forward_backward(a, b, nt=5):
"""Test a forward operator followed by a backward marching one"""
a.data[0, :] = 1.
b.data[0, :] = 1.
eqn_a = Eq(a.forward, a + 1.)
Operator(eqn_a, dle=None, dse=None, time_axis=Forward)(time=nt)
Operator(eqn_a, time_axis=Forward)(time=nt)

eqn_b = Eq(b, a + 1.)
Operator(eqn_b, dle=None, dse=None, time_axis=Backward)(time=nt)
Operator(eqn_b, time_axis=Backward)(time=nt)
for i in range(nt):
assert np.allclose(b.data[i, :], 2. + i, rtol=1.e-12)


def test_forward_backward_overlapping(a, b, nt=5):
"""
Test a forward operator followed by a backward one, but with
overlapping operator definitions.
"""
a.data[0, :] = 1.
b.data[0, :] = 1.
op_fwd = Operator(Eq(a.forward, a + 1.), time_axis=Forward)
op_bwd = Operator(Eq(b, a + 1.), time_axis=Backward)

op_fwd(time=nt)
op_bwd(time=nt)
for i in range(nt):
assert np.allclose(b.data[i, :], 2. + i, rtol=1.e-12)

Expand Down

0 comments on commit d044efc

Please sign in to comment.