Skip to content

Commit 1403dfd

Browse files
committed
compiler: cleanup ParTile
1 parent fb2170d commit 1403dfd

File tree

8 files changed

+54
-34
lines changed

8 files changed

+54
-34
lines changed

devito/arch/compiler.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -899,4 +899,5 @@ def __lookup_cmds__(self):
899899
DEVITO_ARCH. Developers should add new compiler classes here.
900900
"""
901901
compiler_registry.update({'gcc-%s' % i: partial(GNUCompiler, suffix=i)
902-
for i in ['4.9', '5', '6', '7', '8', '9', '10', '11', '12']})
902+
for i in ['4.9', '5', '6', '7', '8', '9', '10',
903+
'11', '12', '13']})

devito/core/operator.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from devito.mpi.routines import mpi_registry
77
from devito.parameters import configuration
88
from devito.operator import Operator
9-
from devito.tools import as_tuple, is_integer, timed_pass, UnboundTuple
9+
from devito.tools import (as_tuple, is_integer, timed_pass,
10+
UnboundTuple, UnboundedMultiTuple)
1011
from devito.types import NThreads
1112

1213
__all__ = ['CoreOperator', 'CustomOperator',
@@ -338,11 +339,11 @@ def __new__(cls, items, rule=None, tag=None):
338339
return obj
339340

340341

341-
class ParTile(tuple, OptOption):
342+
class ParTile(UnboundedMultiTuple, OptOption):
342343

343344
def __new__(cls, items, default=None):
344345
if not items:
345-
return tuple()
346+
return UnboundedMultiTuple()
346347
elif isinstance(items, bool):
347348
if not default:
348349
raise ValueError("Expected `default` value, got None")
@@ -394,7 +395,7 @@ def __new__(cls, items, default=None):
394395
else:
395396
raise ValueError("Expected bool or iterable, got %s instead" % type(items))
396397

397-
obj = super().__new__(cls, items)
398+
obj = super().__new__(cls, *items)
398399
obj.default = as_tuple(default)
399400

400401
return obj

devito/passes/clusters/blocking.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ class BlockSizeGenerator(object):
431431
"""
432432

433433
def __init__(self, par_tile):
434-
self.umt = UnboundedMultiTuple(*par_tile)
434+
self.umt = par_tile
435435
self.tip = -1
436436

437437
# This is for Clusters that need a small par-tile to avoid under-utilizing
@@ -459,11 +459,11 @@ def next(self, prefix, d, clusters):
459459
return self.umt_small.next()
460460

461461
if x:
462-
item = self.umt.curitem
462+
item = self.umt.curitem()
463463
else:
464464
# We can't `self.umt.iter()` because we might still want to
465465
# fallback to `self.umt_small`
466-
item = self.umt.nextitem
466+
item = self.umt.nextitem()
467467

468468
# Handle user-provided rules
469469
# TODO: This is also rudimentary
@@ -474,15 +474,16 @@ def next(self, prefix, d, clusters):
474474
umt = self.umt
475475
else:
476476
umt = self.umt_small
477+
if not x:
478+
umt.iter()
477479
else:
478480
if item.rule in {d.name for d in prefix.itdims}:
479481
umt = self.umt
480482
else:
481483
# This is like "pattern unmatched" -- fallback to `umt_small`
482484
umt = self.umt_small
483-
484-
if not x:
485-
umt.iter()
485+
if not x:
486+
umt.iter()
486487

487488
return umt.next()
488489

devito/passes/iet/languages/openacc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def _make_partree(self, candidates, nthreads=None):
165165
if self._is_offloadable(root) and \
166166
all(i.is_Affine for i in [root] + collapsable) and \
167167
self.par_tile:
168-
tile = self.par_tile.next()
168+
tile = self.par_tile.nextitem()
169169
assert isinstance(tile, UnboundTuple)
170170

171171
body = self.DeviceIteration(gpu_fit=self.gpu_fit, tile=tile,

devito/passes/iet/parpragma.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from devito.passes.iet.langbase import (LangBB, LangTransformer, DeviceAwareMixin,
1616
make_sections_from_imask)
1717
from devito.symbolics import INT, ccode
18-
from devito.tools import UnboundTuple, as_tuple, flatten, is_integer, prod
18+
from devito.tools import as_tuple, flatten, is_integer, prod
1919
from devito.types import Symbol
2020

2121
__all__ = ['PragmaSimdTransformer', 'PragmaShmTransformer',
@@ -622,8 +622,7 @@ def __init__(self, sregistry, options, platform, compiler):
622622
super().__init__(sregistry, options, platform, compiler)
623623

624624
self.gpu_fit = options['gpu-fit']
625-
self.par_tile = UnboundTuple(*options['par-tile'],
626-
default=options['par-tile'].default)
625+
self.par_tile = options['par-tile']
627626
self.par_disabled = options['par-disabled']
628627

629628
def _score_candidate(self, n0, root, collapsable=()):
@@ -659,7 +658,7 @@ def _make_partree(self, candidates, nthreads=None, index=None):
659658
if self._is_offloadable(root):
660659
body = self.DeviceIteration(gpu_fit=self.gpu_fit,
661660
ncollapsed=len(collapsable)+1,
662-
tile=self.par_tile.next(),
661+
tile=self.par_tile.nextitem(),
663662
**root.args)
664663
partree = ParallelTree([], body, nthreads=nthreads)
665664

devito/tools/data_structures.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -670,14 +670,9 @@ def __new__(cls, *items, **kwargs):
670670
obj = super().__new__(cls, tuple(nitems))
671671
obj.last = len(nitems)
672672
obj.current = 0
673-
obj._default = kwargs.get('default', nitems[0])
674673

675674
return obj
676675

677-
@property
678-
def default(self):
679-
return self._default
680-
681676
@property
682677
def prod(self):
683678
return np.prod(self)
@@ -686,11 +681,11 @@ def iter(self):
686681
self.current = 0
687682

688683
def next(self):
689-
if self.last == 0:
684+
if not self:
690685
return None
691686
item = self[self.current]
692687
if self.current == self.last-1 or self.current == -1:
693-
self.current = -1
688+
self.current = self.last
694689
else:
695690
self.current += 1
696691
return item
@@ -703,6 +698,8 @@ def __repr__(self):
703698
return "%s(%s)" % (self.__class__.__name__, ", ".join(sitems))
704699

705700
def __getitem__(self, idx):
701+
if not self:
702+
return None
706703
if isinstance(idx, slice):
707704
start = idx.start or 0
708705
stop = idx.stop or self.last
@@ -754,26 +751,39 @@ class UnboundedMultiTuple(UnboundTuple):
754751

755752
def __new__(cls, *items, **kwargs):
756753
obj = super().__new__(cls, *items, **kwargs)
757-
obj.current = -1
754+
# MultiTuple are un-initialized
755+
obj.current = None
758756
return obj
759757

760-
@property
761758
def curitem(self):
759+
if self.current is None:
760+
raise StopIteration
761+
if not self:
762+
return None
762763
return self[self.current]
763764

764-
@property
765765
def nextitem(self):
766-
return self[min(self.current + 1, max(self.last - 1, 0))]
766+
if not self:
767+
return None
768+
self.iter()
769+
return self.curitem()
767770

768771
def index(self, item):
769772
return self.index(item)
770773

771774
def iter(self):
772-
self.current = min(self.current + 1, self.last - 1)
775+
if self.current is None:
776+
self.current = 0
777+
else:
778+
self.current = min(self.current + 1, self.last - 1)
773779
self[self.current].current = 0
774780
return
775781

776782
def next(self):
777-
if self[self.current].current == -1:
783+
if not self:
784+
return None
785+
if self.current is None:
786+
raise StopIteration
787+
if self[self.current].current >= self[self.current].last:
778788
raise StopIteration
779789
return self[self.current].next()

tests/test_tools.py

+8
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,14 @@ def test_ctypes_to_cstr(dtype, expected):
103103

104104
def test_unbounded_multi_tuple():
105105
ub = UnboundedMultiTuple([1, 2], [3, 4])
106+
with pytest.raises(StopIteration):
107+
ub.next()
108+
109+
with pytest.raises(StopIteration):
110+
assert ub.curitem()
106111

107112
ub.iter()
113+
assert ub.curitem() == (1, 2)
108114
assert ub.next() == 1
109115
assert ub.next() == 2
110116

@@ -121,6 +127,8 @@ def test_unbounded_multi_tuple():
121127
ub.iter()
122128
assert ub.next() == 3
123129

130+
assert ub.nextitem() == (3, 4)
131+
124132

125133
def test_unbound_tuple():
126134
# Make sure we don't drop needed None for 2.5d

tests/test_unexpansion.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -275,16 +275,17 @@ def test_v6(self):
275275
op.cfunction
276276

277277
def test_transpose(self):
278-
shape = (10, 10, 10)
279-
grid = Grid(shape=shape)
278+
shape = (11, 11, 11)
279+
grid = Grid(shape=shape, extent=(10, 10, 10))
280280
x, _, _ = grid.dimensions
281281

282282
u = TimeFunction(name='u', grid=grid, space_order=4)
283283
u1 = TimeFunction(name='u', grid=grid, space_order=4)
284284

285285
# Chessboard-like init
286-
u.data[:] = np.indices(shape).sum(axis=0) % 10 + 1
287-
u1.data[:] = np.indices(shape).sum(axis=0) % 10 + 1
286+
hshape = u.data_with_halo.shape[1:]
287+
u.data_with_halo[:] = np.indices(hshape).sum(axis=0) % 10 + 1
288+
u1.data_with_halo[:] = np.indices(hshape).sum(axis=0) % 10 + 1
288289

289290
eqn = Eq(u.forward, u.dx(x0=x+x.spacing/2).T + 1.)
290291

@@ -293,7 +294,6 @@ def test_transpose(self):
293294

294295
op0.apply(time_M=10)
295296
op1.apply(time_M=10, u=u1)
296-
297297
assert np.allclose(u.data, u1.data, rtol=10e-6)
298298

299299

0 commit comments

Comments
 (0)