Skip to content

Commit

Permalink
api: process injected expression dimensions in case it's not the spar…
Browse files Browse the repository at this point in the history
…se function
  • Loading branch information
mloubout committed Sep 15, 2023
1 parent d715f3f commit 8ae1f5f
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 14 deletions.
5 changes: 4 additions & 1 deletion devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
symbolic_max=d.symbolic_max + h.right)
eqs = [eq.xreplace(subs) for eq in eqs]

dv.Operator(eqs, name=name, **kwargs)()
try:
dv.Operator(eqs, name=name, **kwargs)(time_M=f.shape[f._time_position])
except AttributeError:
dv.Operator(eqs, name=name, **kwargs)()


def smooth(f, g, axis=None):
Expand Down
22 changes: 15 additions & 7 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,16 @@ def _rdim(self):

return DimensionTuple(*rdims, getters=self._gdims)

def _augment_implicit_dims(self, implicit_dims):
def _augment_implicit_dims(self, implicit_dims, extras=None):
if extras is not None:
extra = tuple([i for v in extras for i in v.dimensions])
else:
extra = tuple()

if self.sfunction._sparse_position == -1:
return self.sfunction.dimensions + as_tuple(implicit_dims)
return self.sfunction.dimensions + as_tuple(implicit_dims) + extra
else:
return as_tuple(implicit_dims) + self.sfunction.dimensions
return as_tuple(implicit_dims) + self.sfunction.dimensions + extra

def _coeff_temps(self, implicit_dims):
return []
Expand Down Expand Up @@ -252,8 +257,6 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
interpolation expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Derivatives must be evaluated before the introduction of indirect accesses
try:
_expr = expr.evaluate
Expand All @@ -263,6 +266,9 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):

variables = list(retrieve_function_carriers(_expr))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims)

# List of indirection indices for all adjacent grid points
idx_subs, temps = self._interp_idx(variables, implicit_dims=implicit_dims)

Expand Down Expand Up @@ -295,8 +301,6 @@ def _inject(self, field, expr, implicit_dims=None):
injection expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Make iterable to support inject((u, v), expr=expr)
# or inject((u, v), expr=(expr1, expr2))
fields, exprs = as_tuple(field), as_tuple(expr)
Expand All @@ -315,6 +319,10 @@ def _inject(self, field, expr, implicit_dims=None):
_exprs = exprs

variables = list(v for e in _exprs for v in retrieve_function_carriers(e))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)

variables = variables + list(fields)

# List of indirection indices for all adjacent grid points
Expand Down
8 changes: 2 additions & 6 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,8 +983,7 @@ def bound_symbols(self):
return set(self.parent.bound_symbols)

def _arg_defaults(self, alias=None, **kwargs):
dim = alias or self
return {dim.parent.size_name: range(self.symbolic_size, np.iinfo(np.int64).max)}
return {}

def _arg_values(self, *args, **kwargs):
return {}
Expand Down Expand Up @@ -1466,10 +1465,7 @@ def _arg_defaults(self, _min=None, size=None, **kwargs):
A SteppingDimension does not know its max point and therefore
does not have a size argument.
"""
args = {self.parent.min_name: _min}
if size:
args[self.parent.size_name] = range(size-1, np.iinfo(np.int32).max)
return args
return {self.parent.min_name: _min}

def _arg_values(self, *args, **kwargs):
"""
Expand Down
25 changes: 25 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from sympy import Float

from conftest import assert_structure
from devito import (Grid, Operator, Dimension, SparseFunction, SparseTimeFunction,
Function, TimeFunction, DefaultDimension, Eq,
PrecomputedSparseFunction, PrecomputedSparseTimeFunction,
Expand Down Expand Up @@ -734,3 +735,27 @@ class SparseFirst(SparseFunction):
op(time_M=10)
expected = 10*11/2 # n (n+1)/2
assert np.allclose(s.data, expected)


def test_inject_function():
nt = 11

grid = Grid(shape=(5, 5))
u = TimeFunction(name="u", grid=grid, time_order=2)
src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1,
coordinates=[[0.5, 0.5]])

nfreq = 5
freq_dim = DefaultDimension(name="freq", default_value=nfreq)
omega = Function(name="omega", dimensions=(freq_dim,), shape=(nfreq,), grid=grid)
omega.data.fill(1.)

inj = src.inject(field=u.forward, expr=omega)

op = Operator([inj])

assert_structure(op, ['p_src', 't', 't,p_src,freq', 't,p_src,freq,rsrcx,rsrcy'],
'p_src,t,p_src,freq,rsrcx,rsrcy')

op(time_M=0)
assert u.data[1, 2, 2] == nfreq

0 comments on commit 8ae1f5f

Please sign in to comment.