Skip to content

Commit

Permalink
types: Fixup generation, caching and processing of NPThreads
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Apr 6, 2021
1 parent e9dc85f commit 71de00d
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 12 deletions.
2 changes: 1 addition & 1 deletion devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class ThreadFunction(Callable):
def _make_threads(value, sregistry):
name = sregistry.make_name(prefix='threads')

base_id = 1 + sum(i.data for i in sregistry.npthreads)
base_id = 1 + sum(i.size for i in sregistry.npthreads)

if value is None:
# The npthreads Symbol isn't actually used, but we record the fact
Expand Down
4 changes: 2 additions & 2 deletions devito/operator/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def make_name(self, prefix=None):

return "%s%d" % (prefix, counter())

def make_npthreads(self, value):
def make_npthreads(self, size):
name = self.make_name(prefix='npthreads')
npthreads = NPThreads(name=name, value=value)
npthreads = NPThreads(name=name, size=size)
self.npthreads.append(npthreads)
return npthreads
25 changes: 22 additions & 3 deletions devito/types/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import sympy

from devito.exceptions import InvalidArgument
from devito.parameters import configuration
from devito.tools import Pickable, as_list, as_tuple, dtype_to_cstr, filter_ordered
from devito.types.array import Array, ArrayObject
Expand Down Expand Up @@ -70,9 +71,27 @@ class NPThreads(NThreadsBase):

name = 'npthreads'

@property
def default_value(self):
return 1
def __new__(cls, **kwargs):
obj = super().__new__(cls, **kwargs)

# Size of the thread pool
obj.size = kwargs['size']

return obj

def _arg_values(self, **kwargs):
if self.name in kwargs:
v = kwargs.pop(self.name)
if v < self.size:
return {self.name: v}
else:
raise InvalidArgument("Illegal `%s=%d`. It must be `%s<%d`"
% (self.name, v, self.name, self.size))
else:
return self._arg_defaults()

# Pickling support
_pickle_kwargs = NThreadsBase._pickle_kwargs + ['size']


class ThreadID(CustomDimension):
Expand Down
8 changes: 7 additions & 1 deletion tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ConditionalDimension, SubDimension, Constant, Operator, Eq, Dimension,
DefaultDimension, _SymbolCache, clear_cache, solve, VectorFunction,
TensorFunction, TensorTimeFunction, VectorTimeFunction)
from devito.types import Scalar, Symbol, NThreadsBase, DeviceID, ThreadID
from devito.types import Scalar, Symbol, NThreadsBase, DeviceID, NPThreads, ThreadID


@pytest.fixture
Expand Down Expand Up @@ -409,6 +409,12 @@ def test_special_symbols(self):
did1 = DeviceID()
assert did0 is did1

npt0 = NPThreads(name='npt', size=3)
npt1 = NPThreads(name='npt', size=3)
npt2 = NPThreads(name='npt', size=4)
assert npt0 is npt1
assert npt0 is not npt2

def test_symbol_aliasing(self):
"""Test to assert that our aliasing cache isn't defeated by sympys
non-aliasing symbol cache.
Expand Down
30 changes: 26 additions & 4 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension,
SubDimension, SubDomain, TimeFunction, Operator)
from devito.arch import get_gpu_info
from devito.exceptions import InvalidArgument
from devito.ir import Expression, Section, FindNodes, FindSymbols, retrieve_iteration_tree
from devito.passes.iet.languages.openmp import OmpIteration
from devito.types import DeviceID, DeviceRM, Lock, PThreadArray
from devito.types import DeviceID, DeviceRM, Lock, NPThreads, PThreadArray

from conftest import skipif

Expand Down Expand Up @@ -437,7 +438,7 @@ def test_composite_buffering_tasking(self):
assert len([i for i in symbols if isinstance(i, Lock)]) == 1
threads = [i for i in symbols if isinstance(i, PThreadArray)]
assert len(threads) == 1
assert threads[0].size.data == 1
assert threads[0].size.size == 1

op0.apply(time_M=nt-1, dt=0.1)
op1.apply(time_M=nt-1, dt=0.1, u=u1, usave=usave1)
Expand Down Expand Up @@ -475,8 +476,8 @@ def test_composite_buffering_tasking_multi_output(self):
assert len([i for i in symbols if isinstance(i, Lock)]) == 2
threads = [i for i in symbols if isinstance(i, PThreadArray)]
assert len(threads) == 2
assert threads[0].size.data == 1
assert threads[1].size.data == 1
assert threads[0].size.size == 1
assert threads[1].size.size == 1
assert len(op1._func_table) == 4 # usave and vsave eqns are in two diff efuncs

op0.apply(time_M=nt-1)
Expand Down Expand Up @@ -753,3 +754,24 @@ def test_devicerm(self):
assert op.arguments(time_M=2, devicerm=224)[devicerm.name] == 1
assert op.arguments(time_M=2, devicerm=True)[devicerm.name] == 1
assert op.arguments(time_M=2, devicerm=False)[devicerm.name] == 0

def test_npthreads(self):
nt = 10
async_degree = 5
grid = Grid(shape=(300, 300, 300))

u = TimeFunction(name='u', grid=grid)
usave = TimeFunction(name='usave', grid=grid, save=nt)

eqns = [Eq(u.forward, u + 1),
Eq(usave, u.forward)]

op = Operator(eqns, opt=('buffering', 'tasking', 'orchestrate',
{'buf-async-degree': async_degree}))

npthreads0 = self.get_param(op, NPThreads)
assert op.arguments(time_M=2)[npthreads0.name] == 1
assert op.arguments(time_M=2, npthreads0=4)[npthreads0.name] == 4
# Cannot provide a value larger than the thread pool size
with pytest.raises(InvalidArgument):
assert op.arguments(time_M=2, npthreads0=5)
13 changes: 12 additions & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
MPIRegion)
from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar,
PointerArray, Lock, PThreadArray, SharedData, Timer,
DeviceID, ThreadID, TempFunction)
DeviceID, NPThreads, ThreadID, TempFunction)
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
FunctionFromPointer, DefFunction)
from examples.seismic import (demo_model, AcquisitionGeometry,
Expand Down Expand Up @@ -530,6 +530,17 @@ def test_deviceid():
assert did.dtype == new_did.dtype


def test_npthreads():
npt = NPThreads(name='npt', size=3)

pkl_npt = pickle.dumps(npt)
new_npt = pickle.loads(pkl_npt)

assert npt.name == new_npt.name
assert npt.dtype == new_npt.dtype
assert npt.size == new_npt.size


@skipif(['nompi'])
@pytest.mark.parallel(mode=[(1, 'full')])
def test_mpi_fullmode_objects():
Expand Down

0 comments on commit 71de00d

Please sign in to comment.