Skip to content

Commit

Permalink
Merge pull request #2302 from devitocodes/fix-cond-dim-parlang
Browse files Browse the repository at this point in the history
compiler: Fix space conditions with loop blocking
  • Loading branch information
FabioLuporini authored Feb 7, 2024
2 parents 0f249f2 + bb5a335 commit f625d1c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
6 changes: 4 additions & 2 deletions devito/passes/clusters/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,16 @@ def callback(self, clusters, prefix, blk_size_gen=None):
ispace = decompose(c.ispace, d, block_dims)

# Use the innermost BlockDimension in place of `d`
exprs = [uxreplace(e, {d: bd}) for e in c.exprs]
subs = {d: bd}
exprs = [uxreplace(e, subs) for e in c.exprs]
guards = {subs.get(i, i): v for i, v in c.guards.items()}

# The new Cluster properties -- TILABLE is dropped after blocking
properties = c.properties.drop(d)
properties = properties.add(block_dims, c.properties[d] - TILABLES)

processed.append(c.rebuild(exprs=exprs, ispace=ispace,
properties=properties))
guards=guards, properties=properties))
else:
processed.append(c)

Expand Down
32 changes: 32 additions & 0 deletions tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,38 @@ def test_spacial_subsampling(self, opt):
# Verify that u2[x,y]= u[2*x, 2*y]
assert np.allclose(u.data[:-1, 0::2, 0::2], u2.data[:-1, :, :])

def test_spacial_filtering(self):
grid = Grid(shape=(4, 4))
x, y = grid.dimensions

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)

g.data[:] = [[-.7, -.8, 0, .4],
[-.3, -.5, 0, .6],
[.1, .2, -.1, .8],
[.5, .7, 0, .9]]

condition = And(Ge(g, -0.5), Le(g, 0.5))
cd = ConditionalDimension(name='cd1', parent=y, condition=condition)

eqn = Eq(f, g, implicit_dims=cd)

# NOTE: bypass heuristics and enforce loop blocking. Now `cd` is defined
# along `y` but loop blocking will generate two different BlockDimensions,
# `y0_blk0` and `Y`, in place of `y`. So the compiler must be smart
# enough to stick `cd` inside `Y`
opt = ('advanced', {'blockrelax': 'device-aware'})

op = Operator(eqn, opt=opt)

op()

assert np.all(f.data == np.array([[0, 0, 0, .4],
[-.3, -.5, 0, 0],
[.1, .2, -.1, 0],
[.5, 0, 0, 0]], dtype=f.dtype))

def test_time_subsampling_fd(self):
nt = 19
grid = Grid(shape=(11, 11))
Expand Down

0 comments on commit f625d1c

Please sign in to comment.