Skip to content

Commit

Permalink
compiler: fix stride generation with mied dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jun 18, 2024
1 parent 005cbc4 commit d3192eb
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 3 deletions.
8 changes: 6 additions & 2 deletions devito/passes/iet/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ def key1(f, d):
"""
if f.is_regular:
# For paddable objects the following holds:
# `same dim + same halo => same (auto-)padding`
return (d, f._size_halo[d], f.is_autopaddable)
# `same dim + same halo + same dtype => same (auto-)padding`
# Bundle need the actual function dtype
if f.is_Bundle:
return (d, f._size_halo[d], f.is_autopaddable, f.c0.dtype)
else:
return (d, f._size_halo[d], f.is_autopaddable, f.dtype)
else:
return False

Expand Down
2 changes: 1 addition & 1 deletion devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def compiler(self):
return self._settings['compiler']

def single_prec(self, expr=None):
dtype = sympy_dtype(expr or self)
dtype = sympy_dtype(expr) if expr is not None else self.dtype
return dtype in [np.float32, np.float16]

def parenthesize(self, item, level, strict=False):
Expand Down
1 change: 1 addition & 0 deletions docker/Dockerfile.intel
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ RUN apt-get update -y && apt-get dist-upgrade -y && \
libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev

ENV MPI4PY_FLAGS='. /opt/intel/oneapi/setvars.sh intel64 && '
ENV MPI4PY_RC_RECV_MPROBE=0

##############################################################
# ICC image
Expand Down
20 changes: 20 additions & 0 deletions tests/test_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,23 @@ def test_inc_w_default_dims():
assert f.shape[0]*k._default_value == 35
assert np.all(g.data[3] == f.shape[0]*k._default_value)
assert np.all(g.data[4:] == 0)


def test_different_dtype():
space_order = 4

grid = Grid(shape=(4, 4))

f = Function(name='f', grid=grid, space_order=space_order)
b = Function(name='b', grid=grid, space_order=space_order, dtype=np.float64)

f.data[:] = 2.1
b.data[:] = 1.3

eq = Eq(f, b.dx + f.dy)

op1 = Operator(eq, opt=('advanced', {'linearize': True}))

# Check generated code has different strides for different dtypes
assert "bL0(x,y) b[(x)*y_stride0 + (y)]" in str(op1)
assert "L0(x,y) f[(x)*y_stride1 + (y)]" in str(op1)

0 comments on commit d3192eb

Please sign in to comment.