Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Uncache data carriers #1913

Merged
merged 4 commits into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 12 additions & 34 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_cstr,
dtype_to_ctype, frozendict, memoized_meth, sympy_mutex)
from devito.types.args import ArgProvider
from devito.types.caching import Cached
from devito.types.caching import Cached, Uncached
from devito.types.lazy import Evaluable
from devito.types.utils import DimensionTuple

Expand Down Expand Up @@ -417,49 +417,27 @@ def __new__(cls, *args, **kwargs):
__hash__ = Cached.__hash__


class DataSymbol(AbstractSymbol, Cached):
class DataSymbol(AbstractSymbol, Uncached):

"""
A scalar symbol, cached by both Devito and SymPy, which carries data.
A unique scalar symbol that carries data.
"""

@classmethod
def _cache_key(cls, *args, **kwargs):
return cls

def __new__(cls, *args, **kwargs):
key = cls._cache_key(*args, **kwargs)
obj = cls._cache_get(key)

if obj is not None:
return obj

# Not in cache. Create a new Symbol via sympy.Symbol
# Create a new Symbol via sympy.Symbol
name = kwargs.get('name') or args[0]
assumptions, kwargs = cls._filter_assumptions(**kwargs)

# Create new, unique type instance from cls and the symbol name
newcls = type(name, (cls,), dict(cls.__dict__))

# Create the new Symbol and invoke __init__
newobj = sympy.Symbol.__new__(newcls, name, **assumptions)
# Note: use __xnew__ to bypass sympy caching
newobj = sympy.Symbol.__xnew__(cls, name, **assumptions)

# Initialization
newobj._dtype = cls.__dtype_setup__(**kwargs)
newobj.__init_finalize__(*args, **kwargs)

# Store new instance in symbol cache
Cached.__init__(newobj, newcls)

return newobj

__hash__ = Cached.__hash__

# Pickling support

@property
def _pickle_reconstruct(self):
return self.__class__.__base__
__hash__ = Uncached.__hash__


class Scalar(Symbol, ArgProvider):
Expand Down Expand Up @@ -1234,7 +1212,7 @@ def function(self):
__reduce_ex__ = Pickable.__reduce_ex__


class Object(AbstractObject, ArgProvider):
class Object(AbstractObject, ArgProvider, Uncached):

"""
Pointer to object with derived type, provided by an outer scope.
Expand All @@ -1246,6 +1224,8 @@ def __init__(self, name, dtype, value=None):
super(Object, self).__init__(name, dtype)
self.value = value

__hash__ = Uncached.__hash__

@property
def _arg_names(self):
return (self.name,)
Expand All @@ -1266,7 +1246,8 @@ def _arg_values(self, **kwargs):
Dictionary of user-provided argument overrides.
"""
if self.name in kwargs:
return {self.name: kwargs.pop(self.name)}
obj = kwargs.pop(self.name)
return {self.name: obj._arg_defaults()[obj.name]}
else:
return self._arg_defaults()

Expand Down Expand Up @@ -1306,9 +1287,6 @@ def pname(self):
def fields(self):
return [i for i, _ in self.pfields]

def _hashable_content(self):
return (self.name, self.pfields)

@cached_property
def _C_typedecl(self):
return Struct(self.pname, [Value(ctypes_to_cstr(j), i) for i, j in self.pfields])
Expand Down
13 changes: 12 additions & 1 deletion devito/types/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from devito.tools import safe_dict_copy


__all__ = ['Cached', '_SymbolCache', 'CacheManager']
__all__ = ['Cached', 'Uncached', '_SymbolCache', 'CacheManager']

_SymbolCache = {}
"""The symbol cache."""
Expand All @@ -21,6 +21,17 @@ def __new__(cls, obj, meta):
return obj


class Uncached(object):

"""
Mixin class for unique, and therefore uncached, symbolic objects
(e.g., data carriers).
"""

def __hash__(self):
return id(self)


class Cached(object):

"""
Expand Down
6 changes: 3 additions & 3 deletions examples/compiler/04_iet-B.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
"text": [
"<Expression a = b + c[i] + 5.0>\n",
"<Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>\n",
"<Expression a = 4*a*b>\n",
"<Expression a = 4*b*a>\n",
"<Expression a = 8.0*a + 6.0/b>\n"
]
}
Expand Down Expand Up @@ -184,7 +184,7 @@
" <Iteration j::j::(0, 5, 1)>\n",
" <Iteration k::k::(0, 7, 1)>\n",
" <Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>\n",
" <Expression a = 4*a*b>\n",
" <Expression a = 4*b*a>\n",
" <Iteration t1::t1::(0, 4, 1)>\n",
" <Expression a = 8.0*a + 6.0/b>\n"
]
Expand Down Expand Up @@ -293,7 +293,7 @@
" for (int k = 0; k <= 7; k += 1)\n",
" {\n",
" d[j][k] = e[t0][t1][i] - f[t][x][y];\n",
" a = 4*a*b;\n",
" a = 4*b*a;\n",
" }\n",
" }\n",
" for (int t1 = 0; t1 <= 4; t1 += 1)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/performance/01_gpu.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@
" for (int y = y_m; y <= y_M; y += 1)\n",
" {\n",
" float r3 = -2.0F*u[time][x + 2][y + 2];\n",
" u[time + 1][x + 2][y + 2] = dt*(r0*u[time][x + 2][y + 2] + c*(r1*r3 + r1*u[time][x + 1][y + 2] + r1*u[time][x + 3][y + 2] + r2*r3 + r2*u[time][x + 2][y + 1] + r2*u[time][x + 2][y + 3]));\n",
" u[time + 1][x + 2][y + 2] = dt*(c*(r1*r3 + r1*u[time][x + 1][y + 2] + r1*u[time][x + 3][y + 2] + r2*r3 + r2*u[time][x + 2][y + 1] + r2*u[time][x + 2][y + 3]) + r0*u[time][x + 2][y + 2]);\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change of order deterministic?

" }\n",
" }\n",
" STOP_TIMER(section0,timers)\n",
Expand Down
28 changes: 26 additions & 2 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ctypes import byref, c_void_p
import weakref

import numpy as np
Expand All @@ -7,7 +8,8 @@
ConditionalDimension, SubDimension, Constant, Operator, Eq, Dimension,
DefaultDimension, _SymbolCache, clear_cache, solve, VectorFunction,
TensorFunction, TensorTimeFunction, VectorTimeFunction)
from devito.types import Scalar, Symbol, NThreadsBase, DeviceID, NPThreads, ThreadID
from devito.types import (DeviceID, NThreadsBase, NPThreads, Object, Scalar, Symbol,
ThreadID)


@pytest.fixture
Expand Down Expand Up @@ -168,6 +170,28 @@ def test_bound_symbol(self):
assert hash(u0._C_symbol) != hash(u1._C_symbol)
assert u0._C_symbol != u1._C_symbol

def test_objects(self):
v0 = byref(c_void_p(3))
v1 = byref(c_void_p(4))

dtype = type('Bar', (c_void_p,), {})

foo0 = Object('foo', dtype, v0)
foo1 = Object('foo', dtype, v0)
foo2 = Object('foo', dtype, v1)

# Obviously:
assert foo0 is not foo1
assert foo0 is not foo2
assert foo1 is not foo2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the dtype and value be checked too for safety?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"checked" how? like foo0.dtype is foo1.dtype ? that's obvious imho


# Carried value doesn't matter -- an Object is always unique
assert hash(foo0) != hash(foo1)

# And obviously:
assert hash(foo0) != hash(foo2)
assert hash(foo1) != hash(foo2)


class TestCaching(object):

Expand Down Expand Up @@ -358,7 +382,7 @@ def test_default_dimension(self):
assert dd0 is not dd1

def test_constant_new(self):
"""Test that new u[x, y] instances don't cache"""
"""Test that new Constant instances don't cache."""
u0 = Constant(name='u')
u0.data = 6.
u1 = Constant(name='u')
Expand Down