diff --git a/numba_cuda/numba/cuda/core/caching.py b/numba_cuda/numba/cuda/core/caching.py index 0158c2b1b..a7f6460dc 100644 --- a/numba_cuda/numba/cuda/core/caching.py +++ b/numba_cuda/numba/cuda/core/caching.py @@ -18,7 +18,7 @@ from pathlib import Path from numba.core import config -from numba.core.serialize import dumps +from numba.cuda.serialize import dumps def _cache_log(msg, *args): diff --git a/numba_cuda/numba/cuda/core/environment.py b/numba_cuda/numba/cuda/core/environment.py new file mode 100644 index 000000000..a1d318e04 --- /dev/null +++ b/numba_cuda/numba/cuda/core/environment.py @@ -0,0 +1,65 @@ +import weakref +import importlib + +from numba import _dynfunc + + +class Environment(_dynfunc.Environment): + """Stores globals and constant pyobjects for runtime. + + It is often needed to convert b/w nopython objects and pyobjects. + """ + + __slots__ = ("env_name", "__weakref__") + # A weak-value dictionary to store live environment with env_name as the + # key. + _memo = weakref.WeakValueDictionary() + + @classmethod + def from_fndesc(cls, fndesc): + try: + # Avoid creating new Env + return cls._memo[fndesc.env_name] + except KeyError: + inst = cls(fndesc.lookup_globals()) + inst.env_name = fndesc.env_name + cls._memo[fndesc.env_name] = inst + return inst + + def can_cache(self): + is_dyn = "__name__" not in self.globals + return not is_dyn + + def __reduce__(self): + return _rebuild_env, ( + self.globals.get("__name__"), + self.consts, + self.env_name, + ) + + def __del__(self): + return + + def __repr__(self): + return f"" + + +def _rebuild_env(modname, consts, env_name): + env = lookup_environment(env_name) + if env is not None: + return env + + mod = importlib.import_module(modname) + env = Environment(mod.__dict__) + env.consts[:] = consts + env.env_name = env_name + # Cache loaded object + Environment._memo[env_name] = env + return env + + +def lookup_environment(env_name): + """Returns the Environment object for the given name; + or None if not found + """ + return Environment._memo.get(env_name) diff --git a/numba_cuda/numba/cuda/lowering.py b/numba_cuda/numba/cuda/lowering.py index aeee3dfd8..8cdfe104e 100644 --- a/numba_cuda/numba/cuda/lowering.py +++ b/numba_cuda/numba/cuda/lowering.py @@ -31,7 +31,7 @@ NumbaDebugInfoWarning, ) from numba.core.funcdesc import default_mangler -from numba.core.environment import Environment +from numba.cuda.core.environment import Environment from numba.core.analysis import compute_use_defs, must_use_alloca from numba.misc.firstlinefinder import get_func_body_first_lineno from numba import version_info diff --git a/numba_cuda/numba/cuda/tests/core/test_serialize.py b/numba_cuda/numba/cuda/tests/core/test_serialize.py index 394fd9be7..afcbe6aec 100644 --- a/numba_cuda/numba/cuda/tests/core/test_serialize.py +++ b/numba_cuda/numba/cuda/tests/core/test_serialize.py @@ -230,7 +230,7 @@ def foo(x): class TestSerializationMisc(TestCase): def test_numba_unpickle(self): # Test that _numba_unpickle is memorizing its output - from numba.core.serialize import _numba_unpickle + from numba.cuda.serialize import _numba_unpickle random_obj = object() bytebuf = pickle.dumps(random_obj)