Skip to content

Commit

Permalink
Merge pull request #1913 from devitocodes/uncache-data-carriers
Browse files Browse the repository at this point in the history
compiler: Uncache data carriers
  • Loading branch information
FabioLuporini authored May 23, 2022
2 parents fb22b58 + 5b9adbe commit 3bf0a19
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 41 deletions.
46 changes: 12 additions & 34 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_cstr,
dtype_to_ctype, frozendict, memoized_meth, sympy_mutex)
from devito.types.args import ArgProvider
from devito.types.caching import Cached
from devito.types.caching import Cached, Uncached
from devito.types.lazy import Evaluable
from devito.types.utils import DimensionTuple

Expand Down Expand Up @@ -417,49 +417,27 @@ def __new__(cls, *args, **kwargs):
__hash__ = Cached.__hash__


class DataSymbol(AbstractSymbol, Cached):
class DataSymbol(AbstractSymbol, Uncached):

"""
A scalar symbol, cached by both Devito and SymPy, which carries data.
A unique scalar symbol that carries data.
"""

@classmethod
def _cache_key(cls, *args, **kwargs):
return cls

def __new__(cls, *args, **kwargs):
key = cls._cache_key(*args, **kwargs)
obj = cls._cache_get(key)

if obj is not None:
return obj

# Not in cache. Create a new Symbol via sympy.Symbol
# Create a new Symbol via sympy.Symbol
name = kwargs.get('name') or args[0]
assumptions, kwargs = cls._filter_assumptions(**kwargs)

# Create new, unique type instance from cls and the symbol name
newcls = type(name, (cls,), dict(cls.__dict__))

# Create the new Symbol and invoke __init__
newobj = sympy.Symbol.__new__(newcls, name, **assumptions)
# Note: use __xnew__ to bypass sympy caching
newobj = sympy.Symbol.__xnew__(cls, name, **assumptions)

# Initialization
newobj._dtype = cls.__dtype_setup__(**kwargs)
newobj.__init_finalize__(*args, **kwargs)

# Store new instance in symbol cache
Cached.__init__(newobj, newcls)

return newobj

__hash__ = Cached.__hash__

# Pickling support

@property
def _pickle_reconstruct(self):
return self.__class__.__base__
__hash__ = Uncached.__hash__


class Scalar(Symbol, ArgProvider):
Expand Down Expand Up @@ -1234,7 +1212,7 @@ def function(self):
__reduce_ex__ = Pickable.__reduce_ex__


class Object(AbstractObject, ArgProvider):
class Object(AbstractObject, ArgProvider, Uncached):

"""
Pointer to object with derived type, provided by an outer scope.
Expand All @@ -1246,6 +1224,8 @@ def __init__(self, name, dtype, value=None):
super(Object, self).__init__(name, dtype)
self.value = value

__hash__ = Uncached.__hash__

@property
def _arg_names(self):
return (self.name,)
Expand All @@ -1266,7 +1246,8 @@ def _arg_values(self, **kwargs):
Dictionary of user-provided argument overrides.
"""
if self.name in kwargs:
return {self.name: kwargs.pop(self.name)}
obj = kwargs.pop(self.name)
return {self.name: obj._arg_defaults()[obj.name]}
else:
return self._arg_defaults()

Expand Down Expand Up @@ -1306,9 +1287,6 @@ def pname(self):
def fields(self):
return [i for i, _ in self.pfields]

def _hashable_content(self):
return (self.name, self.pfields)

@cached_property
def _C_typedecl(self):
return Struct(self.pname, [Value(ctypes_to_cstr(j), i) for i, j in self.pfields])
Expand Down
13 changes: 12 additions & 1 deletion devito/types/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from devito.tools import safe_dict_copy


__all__ = ['Cached', '_SymbolCache', 'CacheManager']
__all__ = ['Cached', 'Uncached', '_SymbolCache', 'CacheManager']

_SymbolCache = {}
"""The symbol cache."""
Expand All @@ -21,6 +21,17 @@ def __new__(cls, obj, meta):
return obj


class Uncached(object):

"""
Mixin class for unique, and therefore uncached, symbolic objects
(e.g., data carriers).
"""

def __hash__(self):
return id(self)


class Cached(object):

"""
Expand Down
6 changes: 3 additions & 3 deletions examples/compiler/04_iet-B.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
"text": [
"<Expression a = b + c[i] + 5.0>\n",
"<Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>\n",
"<Expression a = 4*a*b>\n",
"<Expression a = 4*b*a>\n",
"<Expression a = 8.0*a + 6.0/b>\n"
]
}
Expand Down Expand Up @@ -184,7 +184,7 @@
" <Iteration j::j::(0, 5, 1)>\n",
" <Iteration k::k::(0, 7, 1)>\n",
" <Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>\n",
" <Expression a = 4*a*b>\n",
" <Expression a = 4*b*a>\n",
" <Iteration t1::t1::(0, 4, 1)>\n",
" <Expression a = 8.0*a + 6.0/b>\n"
]
Expand Down Expand Up @@ -293,7 +293,7 @@
" for (int k = 0; k <= 7; k += 1)\n",
" {\n",
" d[j][k] = e[t0][t1][i] - f[t][x][y];\n",
" a = 4*a*b;\n",
" a = 4*b*a;\n",
" }\n",
" }\n",
" for (int t1 = 0; t1 <= 4; t1 += 1)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/performance/01_gpu.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@
" for (int y = y_m; y <= y_M; y += 1)\n",
" {\n",
" float r3 = -2.0F*u[time][x + 2][y + 2];\n",
" u[time + 1][x + 2][y + 2] = dt*(r0*u[time][x + 2][y + 2] + c*(r1*r3 + r1*u[time][x + 1][y + 2] + r1*u[time][x + 3][y + 2] + r2*r3 + r2*u[time][x + 2][y + 1] + r2*u[time][x + 2][y + 3]));\n",
" u[time + 1][x + 2][y + 2] = dt*(c*(r1*r3 + r1*u[time][x + 1][y + 2] + r1*u[time][x + 3][y + 2] + r2*r3 + r2*u[time][x + 2][y + 1] + r2*u[time][x + 2][y + 3]) + r0*u[time][x + 2][y + 2]);\n",
" }\n",
" }\n",
" STOP_TIMER(section0,timers)\n",
Expand Down
28 changes: 26 additions & 2 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ctypes import byref, c_void_p
import weakref

import numpy as np
Expand All @@ -7,7 +8,8 @@
ConditionalDimension, SubDimension, Constant, Operator, Eq, Dimension,
DefaultDimension, _SymbolCache, clear_cache, solve, VectorFunction,
TensorFunction, TensorTimeFunction, VectorTimeFunction)
from devito.types import Scalar, Symbol, NThreadsBase, DeviceID, NPThreads, ThreadID
from devito.types import (DeviceID, NThreadsBase, NPThreads, Object, Scalar, Symbol,
ThreadID)


@pytest.fixture
Expand Down Expand Up @@ -168,6 +170,28 @@ def test_bound_symbol(self):
assert hash(u0._C_symbol) != hash(u1._C_symbol)
assert u0._C_symbol != u1._C_symbol

def test_objects(self):
v0 = byref(c_void_p(3))
v1 = byref(c_void_p(4))

dtype = type('Bar', (c_void_p,), {})

foo0 = Object('foo', dtype, v0)
foo1 = Object('foo', dtype, v0)
foo2 = Object('foo', dtype, v1)

# Obviously:
assert foo0 is not foo1
assert foo0 is not foo2
assert foo1 is not foo2

# Carried value doesn't matter -- an Object is always unique
assert hash(foo0) != hash(foo1)

# And obviously:
assert hash(foo0) != hash(foo2)
assert hash(foo1) != hash(foo2)


class TestCaching(object):

Expand Down Expand Up @@ -358,7 +382,7 @@ def test_default_dimension(self):
assert dd0 is not dd1

def test_constant_new(self):
"""Test that new u[x, y] instances don't cache"""
"""Test that new Constant instances don't cache."""
u0 = Constant(name='u')
u0.data = 6.
u1 = Constant(name='u')
Expand Down

0 comments on commit 3bf0a19

Please sign in to comment.