Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Singletonize special symbols (e.g. nthreads) #1650

Merged
merged 3 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions devito/builtins/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def norm(f, order=2):
# otherwise we would eventually be summing more than expected
p, eqns = f.guard() if f.is_SparseFunction else (f, [])

s = dv.types.Scalar(name='sum', dtype=f.dtype)
s = dv.types.Symbol(name='sum', dtype=f.dtype)

with MPIReduction(f) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] +
Expand Down Expand Up @@ -59,7 +59,7 @@ def sumall(f):
# otherwise we would eventually be summing more than expected
p, eqns = f.guard() if f.is_SparseFunction else (f, [])

s = dv.types.Scalar(name='sum', dtype=f.dtype)
s = dv.types.Symbol(name='sum', dtype=f.dtype)

with MPIReduction(f) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] +
Expand Down Expand Up @@ -113,7 +113,7 @@ def inner(f, g):
# otherwise we would eventually be summing more than expected
rhs, eqns = f.guard(f*g) if f.is_SparseFunction else (f*g, [])

s = dv.types.Scalar(name='sum', dtype=f.dtype)
s = dv.types.Symbol(name='sum', dtype=f.dtype)

with MPIReduction(f, g) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] +
Expand Down
2 changes: 1 addition & 1 deletion devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _autotune(self, args, setup):

@property
def nthreads(self):
nthreads = [i for i in self.input if type(i).__base__ is NThreads]
nthreads = [i for i in self.input if isinstance(i, NThreads)]
if len(nthreads) == 0:
return 1
else:
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class ThreadFunction(Callable):
def _make_threads(value, sregistry):
name = sregistry.make_name(prefix='threads')

base_id = 1 + sum(i.data for i in sregistry.npthreads)
base_id = 1 + sum(i.size for i in sregistry.npthreads)

if value is None:
# The npthreads Symbol isn't actually used, but we record the fact
Expand Down
4 changes: 2 additions & 2 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from devito.symbolics import retrieve_function_carriers, indexify, INT
from devito.tools import powerset, flatten, prod
from devito.types import (ConditionalDimension, Dimension, DefaultDimension, Eq, Inc,
Evaluable, Scalar, SubFunction)
Evaluable, Symbol, SubFunction)

__all__ = ['LinearInterpolator', 'PrecomputedInterpolator']

Expand Down Expand Up @@ -234,7 +234,7 @@ def callback():
for b, v_sub in zip(self._interpolation_coeffs, idx_subs)]

# Accumulate point-wise contributions into a temporary
rhs = Scalar(name='sum', dtype=self.sfunction.dtype)
rhs = Symbol(name='sum', dtype=self.sfunction.dtype)
summands = [Eq(rhs, 0., implicit_dims=self.sfunction.dimensions)]
summands.extend([Inc(rhs, i, implicit_dims=self.sfunction.dimensions)
for i in args])
Expand Down
10 changes: 5 additions & 5 deletions devito/operator/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def __init__(self):
self.counters = {}

# Special symbols
self.nthreads = NThreads(aliases='nthreads0')
self.nthreads_nested = NThreadsNested(aliases='nthreads1')
self.nthreads_nonaffine = NThreadsNonaffine(aliases='nthreads2')
self.nthreads = NThreads()
self.nthreads_nested = NThreadsNested()
self.nthreads_nonaffine = NThreadsNonaffine()
self.threadid = ThreadID(self.nthreads)

# Several groups of pthreads each of size `npthread` may be created
Expand All @@ -36,8 +36,8 @@ def make_name(self, prefix=None):

return "%s%d" % (prefix, counter())

def make_npthreads(self, value):
def make_npthreads(self, size):
name = self.make_name(prefix='npthreads')
npthreads = NPThreads(name=name, value=value)
npthreads = NPThreads(name=name, size=size)
self.npthreads.append(npthreads)
return npthreads
6 changes: 3 additions & 3 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from devito.symbolics import (Uxmapper, compare_ops, estimate_cost, q_constant,
q_leaf, retrieve_indexed, search, uxreplace)
from devito.tools import as_tuple, flatten, split
from devito.types import (Array, TempFunction, Eq, Scalar, ModuloDimension,
from devito.types import (Array, TempFunction, Eq, Symbol, ModuloDimension,
CustomDimension, IncrDimension)

__all__ = ['cire']
Expand Down Expand Up @@ -172,7 +172,7 @@ def make_schedule(self, cluster, context):
return SpacePoint(schedule, exprs, score)

def _make_symbol(self):
return Scalar(name=self.sregistry.make_name('dummy'))
return Symbol(name=self.sregistry.make_name('dummy'))

def _nrepeats(self, cluster):
raise NotImplementedError
Expand Down Expand Up @@ -801,7 +801,7 @@ def lower_schedule(cluster, schedule, sregistry, options):
# Degenerate case: scalar expression
assert writeto.size == 0

obj = Scalar(name=name, dtype=dtype)
obj = Symbol(name=name, dtype=dtype)
expression = Eq(obj, alias)

callback = lambda idx: obj
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from devito.ir import DummyEq, Cluster, Scope
from devito.passes.clusters.utils import cluster_pass, makeit_ssa
from devito.symbolics import count, estimate_cost, q_xop, q_leaf, uxreplace
from devito.types import Scalar
from devito.types import Symbol

__all__ = ['cse']

Expand All @@ -13,7 +13,7 @@ def cse(cluster, sregistry, *args):
"""
Common sub-expressions elimination (CSE).
"""
make = lambda: Scalar(name=sregistry.make_name(), dtype=cluster.dtype).indexify()
make = lambda: Symbol(name=sregistry.make_name(), dtype=cluster.dtype).indexify()
processed = _cse(cluster.exprs, make)

return cluster.rebuild(processed)
Expand Down
8 changes: 4 additions & 4 deletions devito/passes/clusters/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sympy import Add, Mul, collect

from devito.passes.clusters.utils import cluster_pass
from devito.symbolics import estimate_cost, retrieve_scalars
from devito.symbolics import estimate_cost, retrieve_symbols
from devito.tools import ReducerMap

__all__ = ['factorize']
Expand Down Expand Up @@ -155,9 +155,9 @@ def run(expr):

# Collect common temporaries (r0, r1, ...)
w_coeffs = Add(*w_coeffs, evaluate=False)
scalars = retrieve_scalars(w_coeffs)
if scalars:
w_coeffs = collect(w_coeffs, scalars, evaluate=False)
symbols = retrieve_symbols(w_coeffs)
if symbols:
w_coeffs = collect(w_coeffs, symbols, evaluate=False)
try:
terms.extend([Mul(k, collect_const(v), evaluate=False)
for k, v in w_coeffs.items()])
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito.passes.clusters.utils import cluster_pass
from devito.symbolics import pow_to_mul, uxreplace
from devito.tools import DAG, as_tuple, filter_ordered, frozendict, timed_pass
from devito.types import Scalar
from devito.types import Symbol

__all__ = ['Lift', 'fuse', 'eliminate_arrays', 'optimize_pows', 'extract_increments']

Expand Down Expand Up @@ -316,7 +316,7 @@ def extract_increments(cluster, sregistry, *args):
processed = []
for e in cluster.exprs:
if e.is_Increment and e.lhs.function.is_Input:
handle = Scalar(name=sregistry.make_name(), dtype=e.dtype).indexify()
handle = Symbol(name=sregistry.make_name(), dtype=e.dtype).indexify()
if e.rhs.is_Number or e.rhs.is_Symbol:
extracted = e.rhs
else:
Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def filter_args(v, efunc=None):
continue

if efunc is self.root and not (a.is_Input or a.is_Object):
# Temporaries (ie, Scalars, Arrays) *cannot* be args in `root`
# Temporaries (ie, Symbol, Arrays) *cannot* be args in `root`
continue

processed.append(a)
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.passes.iet.langbase import LangBB, LangTransformer, DeviceAwareMixin
from devito.passes.iet.misc import is_on_device
from devito.tools import as_tuple, is_integer, prod
from devito.types import Symbol, NThreadsMixin
from devito.types import Symbol, NThreadsBase

__all__ = ['PragmaSimdTransformer', 'PragmaShmTransformer',
'PragmaDeviceAwareTransformer', 'PragmaLangBB']
Expand Down Expand Up @@ -360,7 +360,7 @@ def _make_parallel(self, iet):
iet = Transformer(mapper).visit(iet)

# The new arguments introduced by this pass
args = [i for i in FindSymbols().visit(iet) if isinstance(i, (NThreadsMixin))]
args = [i for i in FindSymbols().visit(iet) if isinstance(i, (NThreadsBase))]
for n in FindNodes(VExpanded).visit(iet):
args.extend([(n.pointee, True), n.pointer])

Expand Down
6 changes: 3 additions & 3 deletions devito/symbolics/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

__all__ = ['q_leaf', 'q_indexed', 'q_terminal', 'q_function', 'q_routine', 'q_xop',
'q_terminalop', 'q_indirect', 'q_constant', 'q_affine', 'q_linear',
'q_identity', 'q_inc', 'q_scalar', 'q_multivar', 'q_monoaffine',
'q_identity', 'q_inc', 'q_symbol', 'q_multivar', 'q_monoaffine',
'q_dimension']


Expand All @@ -16,9 +16,9 @@
# * Indexed


def q_scalar(expr):
def q_symbol(expr):
try:
return expr.is_Scalar
return expr.is_Symbol
except AttributeError:
return False

Expand Down
8 changes: 4 additions & 4 deletions devito/symbolics/search.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf, q_xop,
q_scalar, q_dimension)
q_symbol, q_dimension)
from devito.tools import as_tuple

__all__ = ['retrieve_indexed', 'retrieve_functions', 'retrieve_function_carriers',
'retrieve_terminals', 'retrieve_xops', 'retrieve_scalars',
'retrieve_terminals', 'retrieve_xops', 'retrieve_symbols',
'retrieve_dimensions', 'search']


Expand Down Expand Up @@ -139,9 +139,9 @@ def retrieve_functions(exprs, mode='all'):
return search(exprs, q_function, mode, 'dfs')


def retrieve_scalars(exprs, mode='all'):
def retrieve_symbols(exprs, mode='all'):
"""Shorthand to retrieve the Scalar in ``exprs``."""
return search(exprs, q_scalar, mode, 'dfs')
return search(exprs, q_symbol, mode, 'dfs')


def retrieve_function_carriers(exprs, mode='all'):
Expand Down
20 changes: 17 additions & 3 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,14 +466,28 @@ class Scalar(Symbol, ArgProvider):
def __dtype_setup__(cls, **kwargs):
return kwargs.get('dtype', np.float32)

def _arg_defaults(self):
return {}
@property
def default_value(self):
return None

@property
def _arg_names(self):
return (self.name,)

def _arg_defaults(self, **kwargs):
if self.default_value is None:
# It is possible that the Scalar value is provided indirectly
# through a wrapper object (e.g., a Dimension spacing `h_x` gets its
# value via a Grid object)
return {}
else:
return {self.name: self.default_value}

def _arg_values(self, **kwargs):
if self.name in kwargs:
return {self.name: kwargs.pop(self.name)}
else:
return {}
return self._arg_defaults()


class AbstractTensor(sympy.ImmutableDenseMatrix, Basic, Pickable, Evaluable):
Expand Down
Loading