From 2dd25ba2fd7305d7ba53a56c03ea7989773a6856 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 15 Sep 2023 09:19:22 -0400 Subject: [PATCH] api: process injected expression dimensions in case it's not the sparse function --- devito/builtins/initializers.py | 5 ++++- devito/operations/interpolators.py | 22 +++++++++++++++------- devito/types/dimension.py | 8 ++------ tests/test_interpolation.py | 25 +++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 14 deletions(-) diff --git a/devito/builtins/initializers.py b/devito/builtins/initializers.py index f338e194e1e..7716e43f9a7 100644 --- a/devito/builtins/initializers.py +++ b/devito/builtins/initializers.py @@ -77,7 +77,10 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs): symbolic_max=d.symbolic_max + h.right) eqs = [eq.xreplace(subs) for eq in eqs] - dv.Operator(eqs, name=name, **kwargs)() + if f.is_TimeFunction: + dv.Operator(eqs, name=name, **kwargs)(time_M=f.shape[f._time_position]) + else: + dv.Operator(eqs, name=name, **kwargs)() def smooth(f, g, axis=None): diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index 92bc3923926..12fe31cae69 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -169,11 +169,16 @@ def _rdim(self): return DimensionTuple(*rdims, getters=self._gdims) - def _augment_implicit_dims(self, implicit_dims): + def _augment_implicit_dims(self, implicit_dims, extras=None): + if extras is not None: + extra = tuple([i for v in extras for i in v.dimensions]) + else: + extra = tuple() + if self.sfunction._sparse_position == -1: - return self.sfunction.dimensions + as_tuple(implicit_dims) + return self.sfunction.dimensions + as_tuple(implicit_dims) + extra else: - return as_tuple(implicit_dims) + self.sfunction.dimensions + return as_tuple(implicit_dims) + self.sfunction.dimensions + extra def _coeff_temps(self, implicit_dims): return [] @@ -252,8 +257,6 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None): interpolation expression, but that should be honored when constructing the operator. """ - implicit_dims = self._augment_implicit_dims(implicit_dims) - # Derivatives must be evaluated before the introduction of indirect accesses try: _expr = expr.evaluate @@ -263,6 +266,9 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None): variables = list(retrieve_function_carriers(_expr)) + # Implicit dimensions + implicit_dims = self._augment_implicit_dims(implicit_dims) + # List of indirection indices for all adjacent grid points idx_subs, temps = self._interp_idx(variables, implicit_dims=implicit_dims) @@ -295,8 +301,6 @@ def _inject(self, field, expr, implicit_dims=None): injection expression, but that should be honored when constructing the operator. """ - implicit_dims = self._augment_implicit_dims(implicit_dims) - # Make iterable to support inject((u, v), expr=expr) # or inject((u, v), expr=(expr1, expr2)) fields, exprs = as_tuple(field), as_tuple(expr) @@ -315,6 +319,10 @@ def _inject(self, field, expr, implicit_dims=None): _exprs = exprs variables = list(v for e in _exprs for v in retrieve_function_carriers(e)) + + # Implicit dimensions + implicit_dims = self._augment_implicit_dims(implicit_dims, variables) + variables = variables + list(fields) # List of indirection indices for all adjacent grid points diff --git a/devito/types/dimension.py b/devito/types/dimension.py index 6044f014690..831b1332848 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -983,8 +983,7 @@ def bound_symbols(self): return set(self.parent.bound_symbols) def _arg_defaults(self, alias=None, **kwargs): - dim = alias or self - return {dim.parent.size_name: range(self.symbolic_size, np.iinfo(np.int64).max)} + return {} def _arg_values(self, *args, **kwargs): return {} @@ -1466,10 +1465,7 @@ def _arg_defaults(self, _min=None, size=None, **kwargs): A SteppingDimension does not know its max point and therefore does not have a size argument. """ - args = {self.parent.min_name: _min} - if size: - args[self.parent.size_name] = range(size-1, np.iinfo(np.int32).max) - return args + return {self.parent.min_name: _min} def _arg_values(self, *args, **kwargs): """ diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 3a22ca1db73..79e816a01b6 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -4,6 +4,7 @@ import pytest from sympy import Float +from conftest import assert_structure from devito import (Grid, Operator, Dimension, SparseFunction, SparseTimeFunction, Function, TimeFunction, DefaultDimension, Eq, PrecomputedSparseFunction, PrecomputedSparseTimeFunction, @@ -734,3 +735,27 @@ class SparseFirst(SparseFunction): op(time_M=10) expected = 10*11/2 # n (n+1)/2 assert np.allclose(s.data, expected) + + +def test_inject_function(): + nt = 11 + + grid = Grid(shape=(5, 5)) + u = TimeFunction(name="u", grid=grid, time_order=2) + src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1, + coordinates=[[0.5, 0.5]]) + + nfreq = 5 + freq_dim = DefaultDimension(name="freq", default_value=nfreq) + omega = Function(name="omega", dimensions=(freq_dim,), shape=(nfreq,), grid=grid) + omega.data.fill(1.) + + inj = src.inject(field=u.forward, expr=omega) + + op = Operator([inj]) + + assert_structure(op, ['p_src', 't', 't,p_src,freq', 't,p_src,freq,rsrcx,rsrcy'], + 'p_src,t,p_src,freq,rsrcx,rsrcy') + + op(time_M=0) + assert u.data[1, 2, 2] == nfreq