Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: fix printer dtype processing #2388

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions devito/passes/iet/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,12 @@ def key1(f, d):
"""
if f.is_regular:
# For paddable objects the following holds:
# `same dim + same halo => same (auto-)padding`
return (d, f._size_halo[d], f.is_autopaddable)
# `same dim + same halo + same dtype => same (auto-)padding`
# Bundle need the actual function dtype
if f.is_Bundle:
return (d, f._size_halo[d], f.is_autopaddable, f.c0.dtype)
else:
return (d, f._size_halo[d], f.is_autopaddable, f.dtype)
else:
return False

Expand Down
19 changes: 11 additions & 8 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def dtype(self):
def compiler(self):
return self._settings['compiler']

def single_prec(self, expr=None):
dtype = sympy_dtype(expr) if expr is not None else self.dtype
return dtype in [np.float32, np.float16]

def parenthesize(self, item, level, strict=False):
if isinstance(item, BooleanFunction):
return "(%s)" % self._print(item)
Expand Down Expand Up @@ -104,9 +108,8 @@ def _print_math_func(self, expr, nest=False, known=None):
except KeyError:
return super()._print_math_func(expr, nest=nest, known=known)

dtype = sympy_dtype(expr)
if dtype is np.float32:
cname += 'f'
if self.single_prec(expr):
cname = '%sf' % cname

args = ', '.join((self._print(arg) for arg in expr.args))

Expand All @@ -116,7 +119,7 @@ def _print_Pow(self, expr):
# Need to override because of issue #1627
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
try:
if expr.exp == -1 and self.dtype == np.float32:
if expr.exp == -1 and self.single_prec():
PREC = precedence(expr)
return '1.0F/%s' % self.parenthesize(expr.base, PREC)
except AttributeError:
Expand Down Expand Up @@ -196,8 +199,8 @@ def _print_Float(self, expr):
elif rv.startswith('.0'):
rv = '0.' + rv[2:]

if self.dtype == np.float32:
rv = rv + 'F'
if self.single_prec():
rv = '%sF' % rv

return rv

Expand Down Expand Up @@ -252,8 +255,8 @@ def _print_ComponentAccess(self, expr):

def _print_TrigonometricFunction(self, expr):
func_name = str(expr.func)
if self.dtype == np.float32:
func_name += 'f'
if self.single_prec():
func_name = '%sf' % func_name
return '%s(%s)' % (func_name, self._print(*expr.args))

def _print_DefFunction(self, expr):
Expand Down
8 changes: 8 additions & 0 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,14 @@ def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None):

return super().inject(field, expr, implicit_dims=implicit_dims)

@property
def forward(self):
"""Symbol for the time-forward state of the TimeFunction."""
i = int(self.time_order / 2) if self.time_order >= 2 else 1
_t = self.dimensions[self._time_position]

return self._subs(_t, _t + i * _t.spacing)


class PrecomputedSparseFunction(AbstractSparseFunction):
"""
Expand Down
1 change: 1 addition & 0 deletions docker/Dockerfile.intel
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ RUN apt-get update -y && apt-get dist-upgrade -y && \
libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev

ENV MPI4PY_FLAGS='. /opt/intel/oneapi/setvars.sh intel64 && '
ENV MPI4PY_RC_RECV_MPROBE=0

##############################################################
# ICC image
Expand Down
20 changes: 20 additions & 0 deletions tests/test_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,23 @@ def test_inc_w_default_dims():
assert f.shape[0]*k._default_value == 35
assert np.all(g.data[3] == f.shape[0]*k._default_value)
assert np.all(g.data[4:] == 0)


def test_different_dtype():
space_order = 4

grid = Grid(shape=(4, 4))

f = Function(name='f', grid=grid, space_order=space_order)
b = Function(name='b', grid=grid, space_order=space_order, dtype=np.float64)

f.data[:] = 2.1
b.data[:] = 1.3

eq = Eq(f, b.dx + f.dy)

op1 = Operator(eq, opt=('advanced', {'linearize': True}))

# Check generated code has different strides for different dtypes
assert "bL0(x,y) b[(x)*y_stride0 + (y)]" in str(op1)
assert "L0(x,y) f[(x)*y_stride1 + (y)]" in str(op1)
Loading