diff --git a/devito/passes/iet/linearization.py b/devito/passes/iet/linearization.py index 71015cc22e..109d19ef29 100644 --- a/devito/passes/iet/linearization.py +++ b/devito/passes/iet/linearization.py @@ -71,8 +71,12 @@ def key1(f, d): """ if f.is_regular: # For paddable objects the following holds: - # `same dim + same halo => same (auto-)padding` - return (d, f._size_halo[d], f.is_autopaddable) + # `same dim + same halo + same dtype => same (auto-)padding` + # Bundle need the actual function dtype + if f.is_Bundle: + return (d, f._size_halo[d], f.is_autopaddable, f.c0.dtype) + else: + return (d, f._size_halo[d], f.is_autopaddable, f.dtype) else: return False diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index fb8dc4c59d..2a25ef5c12 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -39,6 +39,10 @@ def dtype(self): def compiler(self): return self._settings['compiler'] + def single_prec(self, expr=None): + dtype = sympy_dtype(expr) if expr is not None else self.dtype + return dtype in [np.float32, np.float16] + def parenthesize(self, item, level, strict=False): if isinstance(item, BooleanFunction): return "(%s)" % self._print(item) @@ -104,9 +108,8 @@ def _print_math_func(self, expr, nest=False, known=None): except KeyError: return super()._print_math_func(expr, nest=nest, known=known) - dtype = sympy_dtype(expr) - if dtype is np.float32: - cname += 'f' + if self.single_prec(expr): + cname = '%sf' % cname args = ', '.join((self._print(arg) for arg in expr.args)) @@ -116,7 +119,7 @@ def _print_Pow(self, expr): # Need to override because of issue #1627 # E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x try: - if expr.exp == -1 and self.dtype == np.float32: + if expr.exp == -1 and self.single_prec(): PREC = precedence(expr) return '1.0F/%s' % self.parenthesize(expr.base, PREC) except AttributeError: @@ -196,8 +199,8 @@ def _print_Float(self, expr): elif rv.startswith('.0'): rv = '0.' + rv[2:] - if self.dtype == np.float32: - rv = rv + 'F' + if self.single_prec(): + rv = '%sF' % rv return rv @@ -252,8 +255,8 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) - if self.dtype == np.float32: - func_name += 'f' + if self.single_prec(): + func_name = '%sf' % func_name return '%s(%s)' % (func_name, self._print(*expr.args)) def _print_DefFunction(self, expr): diff --git a/devito/types/sparse.py b/devito/types/sparse.py index c92550ed9d..01996afd1f 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -1019,6 +1019,14 @@ def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None): return super().inject(field, expr, implicit_dims=implicit_dims) + @property + def forward(self): + """Symbol for the time-forward state of the TimeFunction.""" + i = int(self.time_order / 2) if self.time_order >= 2 else 1 + _t = self.dimensions[self._time_position] + + return self._subs(_t, _t + i * _t.spacing) + class PrecomputedSparseFunction(AbstractSparseFunction): """ diff --git a/docker/Dockerfile.intel b/docker/Dockerfile.intel index 73db50f7ea..2e2888071f 100644 --- a/docker/Dockerfile.intel +++ b/docker/Dockerfile.intel @@ -60,6 +60,7 @@ RUN apt-get update -y && apt-get dist-upgrade -y && \ libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev ENV MPI4PY_FLAGS='. /opt/intel/oneapi/setvars.sh intel64 && ' +ENV MPI4PY_RC_RECV_MPROBE=0 ############################################################## # ICC image diff --git a/tests/test_linearize.py b/tests/test_linearize.py index 88f533687d..a92ddff484 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -592,3 +592,23 @@ def test_inc_w_default_dims(): assert f.shape[0]*k._default_value == 35 assert np.all(g.data[3] == f.shape[0]*k._default_value) assert np.all(g.data[4:] == 0) + + +def test_different_dtype(): + space_order = 4 + + grid = Grid(shape=(4, 4)) + + f = Function(name='f', grid=grid, space_order=space_order) + b = Function(name='b', grid=grid, space_order=space_order, dtype=np.float64) + + f.data[:] = 2.1 + b.data[:] = 1.3 + + eq = Eq(f, b.dx + f.dy) + + op1 = Operator(eq, opt=('advanced', {'linearize': True})) + + # Check generated code has different strides for different dtypes + assert "bL0(x,y) b[(x)*y_stride0 + (y)]" in str(op1) + assert "L0(x,y) f[(x)*y_stride1 + (y)]" in str(op1)