From 47594735c41e8909e29c361e2b811390e4be5352 Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 24 Jul 2023 08:35:26 -0400 Subject: [PATCH] compiler: prevent Eq dims to be lost if only implicit --- devito/ir/clusters/cluster.py | 3 ++- tests/test_dse.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 857e62f7c6..0dc3200b4f 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -172,7 +172,8 @@ def used_dimensions(self): example, reduction or redundant (i.e., invariant) Dimensions won't appear in an expression. """ - return {i for i in self.free_symbols if i.is_Dimension} + idims = set.union(*[set(e.implicit_dims) for e in self.exprs]) + return {i for i in self.free_symbols if i.is_Dimension} | idims @cached_property def scope(self): diff --git a/tests/test_dse.py b/tests/test_dse.py index eee0027896..b1ed9394ec 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -418,6 +418,19 @@ def test_contracted(self, exprs, expected, visit): for j in trees] == expected assert "".join(mapper.get(i.dim.name, i.dim.name) for i in iters) == visit + def test_implicit_only(self): + grid = Grid(shape=(5, 5)) + time = grid.time_dim + u = TimeFunction(name="u", grid=grid, time_order=1) + idimeq = Eq(Symbol('s'), 1, implicit_dims=time) + + op = Operator([Eq(u.forward, u + 1.), idimeq]) + trees = retrieve_iteration_tree(op) + + assert len(trees) == 2 + assert_structure(op, ['t,x,y', 't'], 'txy') + assert trees[1].dimensions == [time] + class TestAliases(object):