Skip to content

Commit

Permalink
api: Extend par-tile opt-option
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Feb 10, 2022
1 parent 108814f commit c1d4f29
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
47 changes: 43 additions & 4 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Iterable

from devito.core.autotuning import autotune
from devito.exceptions import InvalidOperator
from devito.logger import warning
Expand Down Expand Up @@ -219,6 +221,15 @@ class OptOption(object):
pass


class ParTileArg(tuple):

def __new__(cls, items, shm=0, tag=None):
obj = super().__new__(cls, items)
obj.shm = shm
obj.tag = tag
return obj


class ParTile(tuple, OptOption):

def __new__(cls, items, default=None):
Expand All @@ -229,11 +240,39 @@ def __new__(cls, items, default=None):
raise ValueError("Expected `default` value, got None")
items = (as_tuple(default),)
elif isinstance(items, tuple):
# Normalize to tuple of tuples
if is_integer(items[0]):
items = (items,)
if not items:
raise ValueError("Expected at least one value")

# Normalize to tuple of ParTileArgs

x = items[0]
if is_integer(x):
# E.g., (32, 4, 8)
items = (ParTileArg(items),)

elif isinstance(x, Iterable):
if not x:
raise ValueError("Expected at least one value")

try:
y = items[1]
if is_integer(y):
# E.g., ((32, 4, 8), 1)
# E.g., ((32, 4, 8), 1, 'tag')
items = (ParTileArg(*items),)
else:
try:
# E.g., (((32, 4, 8), 1), ((32, 4, 4), 2))
# E.g., (((32, 4, 8), 1, 'tag0'), ((32, 4, 4), 2, 'tag1'))
items = tuple(ParTileArg(*i) for i in items)
except TypeError:
# E.g., ((32, 4, 8), (32, 4, 4))
items = tuple(ParTileArg(i) for i in items)
except IndexError:
# E.g., ((32, 4, 8),)
items = (ParTileArg(x),)
else:
items = tuple(tuple(i) for i in items)
raise ValueError("Expected int or tuple, got %s instead" % type(x))
else:
raise ValueError("Expected bool or tuple, got %s instead" % type(items))

Expand Down
3 changes: 3 additions & 0 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def test_cache_blocking_structure_optrelax():
(True, ((16, 16, 16), (16, 16, 16))),
((32, 4, 4), ((32, 4, 4), (32, 4, 4))),
(((16, 4), (16,)), ((16, 4, 4), (16, 16, 16))),
(((32, 4, 4), 1), ((32, 4, 4), (32, 4, 4))),
(((32, 4, 4), 1, 'tag0'), ((32, 4, 4), (32, 4, 4))),
((((32, 4, 4), 1, 'tag0'), ((32, 4, 4), 2)), ((32, 4, 4), (32, 4, 4))),
])
def test_cache_blocking_structure_optpartile(par_tile, expected):
grid = Grid(shape=(8, 8, 8))
Expand Down

0 comments on commit c1d4f29

Please sign in to comment.