Skip to content

Commit

Permalink
compiler: Refactor _generate_macros
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Jun 6, 2024
1 parent 6285538 commit 2e5d1eb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
27 changes: 17 additions & 10 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def _generate_macros(iet, tracker=None, **kwargs):
headers = sorted((ccode(define), ccode(expr)) for define, expr in headers)

# Generate Macros from higher-level SymPy objects
for i in FindApplications().visit(iet):
headers.extend(_generate_macros_math(i))
headers.extend(_generate_macros_math(iet))

# Remove redundancies while preserving the order
headers = filter_ordered(headers)
Expand All @@ -170,12 +169,12 @@ def _generate_macros(iet, tracker=None, **kwargs):

def _generate_macros_findexeds(iet, sregistry=None, tracker=None, **kwargs):
indexeds = FindSymbols('indexeds').visit(iet)
indexeds = [i for i in indexeds if isinstance(i, FIndexed)]
if not indexeds:
findexeds = [i for i in indexeds if isinstance(i, FIndexed)]
if not findexeds:
return iet

subs = {}
for i in indexeds:
for i in findexeds:
try:
v = tracker[i.base].v
subs[i] = v.func(v.base, *i.indices)
Expand All @@ -194,22 +193,30 @@ def _generate_macros_findexeds(iet, sregistry=None, tracker=None, **kwargs):
return iet


def _generate_macros_math(iet):
headers = []
for i in FindApplications().visit(iet):
headers.extend(_lower_macro_math(i))

return headers


@singledispatch
def _generate_macros_math(expr):
def _lower_macro_math(expr):
return ()


@_generate_macros_math.register(Min)
@_generate_macros_math.register(sympy.Min)
@_lower_macro_math.register(Min)
@_lower_macro_math.register(sympy.Min)
def _(expr):
if has_integer_args(*expr.args) and len(expr.args) == 2:
return (('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),)
else:
return ()


@_generate_macros_math.register(Max)
@_generate_macros_math.register(sympy.Max)
@_lower_macro_math.register(Max)
@_lower_macro_math.register(sympy.Max)
def _(expr):
if has_integer_args(*expr.args) and len(expr.args) == 2:
return (('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))')),)
Expand Down
2 changes: 1 addition & 1 deletion devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,7 +1546,7 @@ class VirtualDimension(CustomDimension):
def __init_finalize__(self, name, parent=None):
super().__init_finalize__(name, parent=parent,
symbolic_min=sympy.S.Zero,
symbolic_max=sympy.S.One)
symbolic_max=sympy.S.Zero)


# ***
Expand Down

0 comments on commit 2e5d1eb

Please sign in to comment.