diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py
index dfca4f493..258c3c46d 100644
--- a/numba_cuda/numba/cuda/codegen.py
+++ b/numba_cuda/numba/cuda/codegen.py
@@ -1,6 +1,7 @@
 from llvmlite import ir
-from numba.core import config, serialize
+from numba.core import config
+from numba.cuda import serialize
 from .cudadrv import devices, driver, nvvm, runtime, nvrtc
 from numba.cuda.core.codegen import Codegen, CodeLibrary
 from numba.cuda.cudadrv.libs import get_cudalib
diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py
index 6cc140757..aa980c481 100644
--- a/numba_cuda/numba/cuda/cudadrv/driver.py
+++ b/numba_cuda/numba/cuda/cudadrv/driver.py
@@ -44,8 +44,8 @@
 from numba import mviewbuf
-from numba.core import serialize, config
-from numba.cuda import utils
+from numba.core import config
+from numba.cuda import utils, serialize
 from .error import CudaSupportError, CudaDriverError
 from .drvapi import API_PROTOTYPES
 from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index f75444079..205ca3ecd 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -8,8 +8,8 @@
 import weakref
 import uuid
-from numba.core import compiler, serialize, sigutils, types, typing, config
-from numba.cuda import utils
+from numba.core import compiler, sigutils, types, typing, config
+from numba.cuda import serialize, utils
 from numba.cuda.core.caching import Cache, CacheImpl, NullCache
 from numba.core.compiler_lock import global_compiler_lock
 from numba.core.dispatcher import _DispatcherBase
diff --git a/numba_cuda/numba/cuda/serialize.py b/numba_cuda/numba/cuda/serialize.py
new file mode 100644
index 000000000..0c23198d8
--- /dev/null
+++ b/numba_cuda/numba/cuda/serialize.py
@@ -0,0 +1,264 @@
+"""
+Serialization support for compiled functions.
+"""
+
+import sys
+import abc
+import io
+import copyreg
+
+
+import pickle
+from numba import cloudpickle
+from llvmlite import ir
+
+
+#
+# Pickle support
+#
+
+
+def _rebuild_reduction(cls, *args):
+    """
+    Global hook to rebuild a given class from its __reduce__ arguments.
+    """
+    return cls._rebuild(*args)
+
+
+# Keep unpickled object via `numba_unpickle` alive.
+_unpickled_memo = {}
+
+
+def _numba_unpickle(address, bytedata, hashed):
+    """Used by `numba_unpickle` from _helperlib.c
+
+    Parameters
+    ----------
+    address : int
+    bytedata : bytes
+    hashed : bytes
+
+    Returns
+    -------
+    obj : object
+        unpickled object
+    """
+    key = (address, hashed)
+    try:
+        obj = _unpickled_memo[key]
+    except KeyError:
+        _unpickled_memo[key] = obj = cloudpickle.loads(bytedata)
+    return obj
+
+
+def dumps(obj):
+    """Similar to `pickle.dumps()`. Returns the serialized object in bytes."""
+    pickler = NumbaPickler
+    with io.BytesIO() as buf:
+        p = pickler(buf, protocol=4)
+        p.dump(obj)
+        pickled = buf.getvalue()
+
+    return pickled
+
+
+def runtime_build_excinfo_struct(static_exc, exc_args):
+    exc, static_args, locinfo = cloudpickle.loads(static_exc)
+    real_args = []
+    exc_args_iter = iter(exc_args)
+    for arg in static_args:
+        if isinstance(arg, ir.Value):
+            real_args.append(next(exc_args_iter))
+        else:
+            real_args.append(arg)
+    return (exc, tuple(real_args), locinfo)
+
+
+# Alias to pickle.loads to allow `serialize.loads()`
+loads = cloudpickle.loads
+
+
+class _CustomPickled:
+    """A wrapper for objects that must be pickled with `NumbaPickler`.
+
+    Standard `pickle` will pick up the implementation registered via `copyreg`.
+    This will spawn a `NumbaPickler` instance to serialize the data.
+
+    `NumbaPickler` overrides the handling of this type so as not to spawn a
+    new pickler for the object when it is already being pickled by a
+    `NumbaPickler`.
+    """
+
+    __slots__ = "ctor", "states"
+
+    def __init__(self, ctor, states):
+        self.ctor = ctor
+        self.states = states
+
+    def _reduce(self):
+        return _CustomPickled._rebuild, (self.ctor, self.states)
+
+    @classmethod
+    def _rebuild(cls, ctor, states):
+        return cls(ctor, states)
+
+
+def _unpickle__CustomPickled(serialized):
+    """standard unpickling for `_CustomPickled`.
+
+    Uses `NumbaPickler` to load.
+    """
+    ctor, states = loads(serialized)
+    return _CustomPickled(ctor, states)
+
+
+def _pickle__CustomPickled(cp):
+    """standard pickling for `_CustomPickled`.
+
+    Uses `NumbaPickler` to dump.
+    """
+    serialized = dumps((cp.ctor, cp.states))
+    return _unpickle__CustomPickled, (serialized,)
+
+
+# Register custom pickling for the standard pickler.
+copyreg.pickle(_CustomPickled, _pickle__CustomPickled)
+
+
+def custom_reduce(cls, states):
+    """For customizing object serialization in `__reduce__`.
+
+    Object states provided here are used as keyword arguments to the
+    `._rebuild()` class method.
+
+    Parameters
+    ----------
+    states : dict
+        Dictionary of object states to be serialized.
+
+    Returns
+    -------
+    result : tuple
+        This tuple conforms to the return type requirement for `__reduce__`.
+    """
+    return custom_rebuild, (_CustomPickled(cls, states),)
+
+
+def custom_rebuild(custom_pickled):
+    """Customized object deserialization.
+
+    This function is referenced internally by `custom_reduce()`.
+    """
+    cls, states = custom_pickled.ctor, custom_pickled.states
+    return cls._rebuild(**states)
+
+
+def is_serialiable(obj):
+    """Check if *obj* can be serialized.
+
+    Parameters
+    ----------
+    obj : object
+
+    Returns
+    -------
+    can_serialize : bool
+    """
+    with io.BytesIO() as fout:
+        pickler = NumbaPickler(fout)
+        try:
+            pickler.dump(obj)
+        except pickle.PicklingError:
+            return False
+        else:
+            return True
+
+
+def _no_pickle(obj):
+    raise pickle.PicklingError(f"Pickling of {type(obj)} is unsupported")
+
+
+def disable_pickling(typ):
+    """This is called on a type to disable pickling"""
+    NumbaPickler.disabled_types.add(typ)
+    # Return `typ` to allow use as a decorator
+    return typ
+
+
+class NumbaPickler(cloudpickle.CloudPickler):
+    disabled_types = set()
+    """A set of types for which pickling is disabled.
+    """
+
+    def reducer_override(self, obj):
+        # Overridden to disable pickling of certain types
+        if type(obj) in self.disabled_types:
+            _no_pickle(obj)  # noreturn
+        return super().reducer_override(obj)
+
+
+def _custom_reduce__custompickled(cp):
+    return cp._reduce()
+
+
+NumbaPickler.dispatch_table[_CustomPickled] = _custom_reduce__custompickled
+
+
+class ReduceMixin(abc.ABC):
+    """A mixin class for objects that should be reduced by the NumbaPickler
+    instead of the standard pickler.
+    """
+
+    # Subclass MUST override the below methods
+
+    @abc.abstractmethod
+    def _reduce_states(self):
+        raise NotImplementedError
+
+    @abc.abstractclassmethod
+    def _rebuild(cls, **kwargs):
+        raise NotImplementedError
+
+    # Subclass can override the below methods
+
+    def _reduce_class(self):
+        return self.__class__
+
+    # Private methods
+
+    def __reduce__(self):
+        return custom_reduce(self._reduce_class(), self._reduce_states())
+
+
+class PickleCallableByPath:
+    """Wrap a callable object to be pickled by path to workaround limitation
+    in pickling due to non-pickleable objects in function non-locals.
+
+    Note:
+    - Do not use this as a decorator.
+    - Wrapped object must be a global that exists in its parent module and it
+      can be imported by `from the_module import the_object`.
+
+    Usage:
+
+    >>> def my_fn(x):
+    >>>     ...
+    >>> wrapped_fn = PickleCallableByPath(my_fn)
+    >>> # refer to `wrapped_fn` instead of `my_fn`
+    """
+
+    def __init__(self, fn):
+        self._fn = fn
+
+    def __call__(self, *args, **kwargs):
+        return self._fn(*args, **kwargs)
+
+    def __reduce__(self):
+        return type(self)._rebuild, (
+            self._fn.__module__,
+            self._fn.__name__,
+        )
+
+    @classmethod
+    def _rebuild(cls, modname, fn_path):
+        return cls(getattr(sys.modules[modname], fn_path))
diff --git a/numba_cuda/numba/cuda/tests/core/serialize_usecases.py b/numba_cuda/numba/cuda/tests/core/serialize_usecases.py
new file mode 100644
index 000000000..e96ca88c1
--- /dev/null
+++ b/numba_cuda/numba/cuda/tests/core/serialize_usecases.py
@@ -0,0 +1,110 @@
+"""
+Separate module with function samples for serialization tests,
+to avoid issues with __main__.
+"""
+
+import math
+from math import sqrt
+import numpy as np
+import numpy.random as nprand
+
+from numba import jit
+from numba.core import types
+
+
+@jit((types.int32, types.int32))
+def add_with_sig(a, b):
+    return a + b
+
+
+@jit
+def add_without_sig(a, b):
+    return a + b
+
+
+@jit(nopython=True)
+def add_nopython(a, b):
+    return a + b
+
+
+@jit(nopython=True)
+def add_nopython_fail(a, b):
+    object()
+    return a + b
+
+
+def closure(a):
+    @jit(nopython=True)
+    def inner(b, c):
+        return a + b + c
+
+    return inner
+
+
+K = 3.0
+
+
+def closure_with_globals(x, **jit_args):
+    @jit(**jit_args)
+    def inner(y):
+        # Exercise a builtin function and a module-level constant
+        k = max(K, K + 1)
+        # Exercise two functions from another module, one accessed with
+        # dotted notation, one imported explicitly.
+        return math.hypot(x, y) + sqrt(k)
+
+    return inner
+
+
+@jit(nopython=True)
+def other_function(x, y):
+    return math.hypot(x, y)
+
+
+@jit(forceobj=True)
+def get_global_objmode(x):
+    return K * x
+
+
+@jit(nopython=True)
+def get_renamed_module(x):
+    nprand.seed(42)
+    return np.cos(x), nprand.random()
+
+
+def closure_calling_other_function(x):
+    @jit(nopython=True)
+    def inner(y, z):
+        return other_function(x, y) + z
+
+    return inner
+
+
+def closure_calling_other_closure(x):
+    @jit(nopython=True)
+    def other_inner(y):
+        return math.hypot(x, y)
+
+    @jit(nopython=True)
+    def inner(y):
+        return other_inner(y) + x
+
+    return inner
+
+
+# A dynamic function calling a builtin function
+def _get_dyn_func(**jit_args):
+    code = """
+        def dyn_func(x):
+            res = 0
+            for i in range(x):
+                res += x
+            return res
+        """
+    ns = {}
+    exec(code.strip(), ns)
+    return jit(**jit_args)(ns["dyn_func"])
+
+
+dyn_func = _get_dyn_func(nopython=True)
+dyn_func_objmode = _get_dyn_func(forceobj=True)
diff --git a/numba_cuda/numba/cuda/tests/core/test_serialize.py b/numba_cuda/numba/cuda/tests/core/test_serialize.py
new file mode 100644
index 000000000..dd08f45b0
--- /dev/null
+++ b/numba_cuda/numba/cuda/tests/core/test_serialize.py
@@ -0,0 +1,359 @@
+import contextlib
+import gc
+import pickle
+import runpy
+import subprocess
+import sys
+import unittest
+from multiprocessing import get_context
+
+import numba
+from numba.core.errors import TypingError
+from numba.tests.support import TestCase
+from numba.core.target_extension import resolve_dispatcher_from_str
+from numba.cloudpickle import dumps, loads
+
+
+class TestDispatcherPickling(TestCase):
+    def run_with_protocols(self, meth, *args, **kwargs):
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            meth(proto, *args, **kwargs)
+
+    @contextlib.contextmanager
+    def simulate_fresh_target(self):
+        hwstr = "cpu"
+        dispatcher_cls = resolve_dispatcher_from_str(hwstr)
+        old_descr = dispatcher_cls.targetdescr
+        # Simulate fresh targetdescr
+        dispatcher_cls.targetdescr = type(dispatcher_cls.targetdescr)(hwstr)
+        try:
+            yield
+        finally:
+            # Be sure to reinstantiate old descriptor, otherwise other
+            # objects may be out of sync.
+            dispatcher_cls.targetdescr = old_descr
+
+    def check_call(self, proto, func, expected_result, args):
+        def check_result(func):
+            if isinstance(expected_result, type) and issubclass(
+                expected_result, Exception
+            ):
+                self.assertRaises(expected_result, func, *args)
+            else:
+                self.assertPreciseEqual(func(*args), expected_result)
+
+        # Control
+        check_result(func)
+        pickled = pickle.dumps(func, proto)
+        with self.simulate_fresh_target():
+            new_func = pickle.loads(pickled)
+            check_result(new_func)
+
+    def test_call_with_sig(self):
+        from .serialize_usecases import add_with_sig
+
+        self.run_with_protocols(self.check_call, add_with_sig, 5, (1, 4))
+        # Compilation has been disabled => float inputs will be coerced to int
+        self.run_with_protocols(self.check_call, add_with_sig, 5, (1.2, 4.2))
+
+    def test_call_without_sig(self):
+        from .serialize_usecases import add_without_sig
+
+        self.run_with_protocols(self.check_call, add_without_sig, 5, (1, 4))
+        self.run_with_protocols(
+            self.check_call, add_without_sig, 5.5, (1.2, 4.3)
+        )
+        # Object mode is enabled
+        self.run_with_protocols(
+            self.check_call, add_without_sig, "abc", ("a", "bc")
+        )
+
+    def test_call_nopython(self):
+        from .serialize_usecases import add_nopython
+
+        self.run_with_protocols(self.check_call, add_nopython, 5.5, (1.2, 4.3))
+        # Object mode is disabled
+        self.run_with_protocols(
+            self.check_call, add_nopython, TypingError, (object(), object())
+        )
+
+    def test_call_nopython_fail(self):
+        from .serialize_usecases import add_nopython_fail
+
+        # Compilation fails
+        self.run_with_protocols(
+            self.check_call, add_nopython_fail, TypingError, (1, 2)
+        )
+
+    def test_call_objmode_with_global(self):
+        from .serialize_usecases import get_global_objmode
+
+        self.run_with_protocols(
+            self.check_call, get_global_objmode, 7.5, (2.5,)
+        )
+
+    def test_call_closure(self):
+        from .serialize_usecases import closure
+
+        inner = closure(1)
+        self.run_with_protocols(self.check_call, inner, 6, (2, 3))
+
+    def check_call_closure_with_globals(self, **jit_args):
+        from .serialize_usecases import closure_with_globals
+
+        inner = closure_with_globals(3.0, **jit_args)
+        self.run_with_protocols(self.check_call, inner, 7.0, (4.0,))
+
+    def test_call_closure_with_globals_nopython(self):
+        self.check_call_closure_with_globals(nopython=True)
+
+    def test_call_closure_with_globals_objmode(self):
+        self.check_call_closure_with_globals(forceobj=True)
+
+    def test_call_closure_calling_other_function(self):
+        from .serialize_usecases import closure_calling_other_function
+
+        inner = closure_calling_other_function(3.0)
+        self.run_with_protocols(self.check_call, inner, 11.0, (4.0, 6.0))
+
+    def test_call_closure_calling_other_closure(self):
+        from .serialize_usecases import closure_calling_other_closure
+
+        inner = closure_calling_other_closure(3.0)
+        self.run_with_protocols(self.check_call, inner, 8.0, (4.0,))
+
+    def test_call_dyn_func(self):
+        from .serialize_usecases import dyn_func
+
+        # Check serializing a dynamically-created function
+        self.run_with_protocols(self.check_call, dyn_func, 36, (6,))
+
+    def test_call_dyn_func_objmode(self):
+        from .serialize_usecases import dyn_func_objmode
+
+        # Same with an object mode function
+        self.run_with_protocols(self.check_call, dyn_func_objmode, 36, (6,))
+
+    def test_renamed_module(self):
+        from .serialize_usecases import get_renamed_module
+
+        # Issue #1559: using a renamed module (e.g. `import numpy as np`)
+        # should not fail serializing
+        expected = get_renamed_module(0.0)
+        self.run_with_protocols(
+            self.check_call, get_renamed_module, expected, (0.0,)
+        )
+
+    def test_other_process(self):
+        """
+        Check that reconstructing doesn't depend on resources already
+        instantiated in the original process.
+        """
+        from .serialize_usecases import closure_calling_other_closure
+
+        func = closure_calling_other_closure(3.0)
+        pickled = pickle.dumps(func)
+        code = """if 1:
+            import pickle
+
+            data = {pickled!r}
+            func = pickle.loads(data)
+            res = func(4.0)
+            assert res == 8.0, res
+            """.format(**locals())
+        subprocess.check_call([sys.executable, "-c", code])
+
+    def test_reuse(self):
+        """
+        Check that deserializing the same function multiple times re-uses
+        the same dispatcher object.
+
+        Note that "same function" is intentionally under-specified.
+        """
+        from .serialize_usecases import closure
+
+        func = closure(5)
+        pickled = pickle.dumps(func)
+        func2 = closure(6)
+        pickled2 = pickle.dumps(func2)
+
+        f = pickle.loads(pickled)
+        g = pickle.loads(pickled)
+        h = pickle.loads(pickled2)
+        self.assertIs(f, g)
+        self.assertEqual(f(2, 3), 10)
+        g.disable_compile()
+        self.assertEqual(g(2, 4), 11)
+
+        self.assertIsNot(f, h)
+        self.assertEqual(h(2, 3), 11)
+
+        # Now make sure the original object doesn't exist when deserializing
+        func = closure(7)
+        func(42, 43)
+        pickled = pickle.dumps(func)
+        del func
+        gc.collect()
+
+        f = pickle.loads(pickled)
+        g = pickle.loads(pickled)
+        self.assertIs(f, g)
+        self.assertEqual(f(2, 3), 12)
+        g.disable_compile()
+        self.assertEqual(g(2, 4), 13)
+
+    def test_imp_deprecation(self):
+        """
+        The imp module was deprecated in v3.4 in favour of importlib
+        """
+        code = """if 1:
+            import pickle
+            import warnings
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter('always', DeprecationWarning)
+                from numba import njit
+                @njit
+                def foo(x):
+                    return x + 1
+                foo(1)
+                serialized_foo = pickle.dumps(foo)
+            for x in w:
+                if 'serialize.py' in x.filename:
+                    assert "the imp module is deprecated" not in x.msg
+        """
+        subprocess.check_call([sys.executable, "-c", code])
+
+
+class TestSerializationMisc(TestCase):
+    def test_numba_unpickle(self):
+        # Test that _numba_unpickle is memoizing its output
+        from numba.core.serialize import _numba_unpickle
+
+        random_obj = object()
+        bytebuf = pickle.dumps(random_obj)
+        hashed = hash(random_obj)
+
+        got1 = _numba_unpickle(id(random_obj), bytebuf, hashed)
+        # not the original object
+        self.assertIsNot(got1, random_obj)
+        got2 = _numba_unpickle(id(random_obj), bytebuf, hashed)
+        # unpickled results are the same objects
+        self.assertIs(got1, got2)
+
+
+class TestCloudPickleIssues(TestCase):
+    """This test case includes issues specific to the cloudpickle implementation."""
+
+    _numba_parallel_test_ = False
+
+    def test_dynamic_class_reset_on_unpickle(self):
+        # a dynamic class
+        class Klass:
+            classvar = None
+
+        def mutator():
+            Klass.classvar = 100
+
+        def check():
+            self.assertEqual(Klass.classvar, 100)
+
+        saved = dumps(Klass)
+        mutator()
+        check()
+        loads(saved)
+        # Without the patch, each `loads(saved)` will reset `Klass.classvar`
+        check()
+        loads(saved)
+        check()
+
+    @unittest.skipIf(
+        __name__ == "__main__", "Test cannot run when module is __main__"
+    )
+    def test_main_class_reset_on_unpickle(self):
+        mp = get_context("spawn")
+        proc = mp.Process(target=check_main_class_reset_on_unpickle)
+        proc.start()
+        proc.join(timeout=60)
+        self.assertEqual(proc.exitcode, 0)
+
+    def test_dynamic_class_reset_on_unpickle_new_proc(self):
+        # a dynamic class
+        class Klass:
+            classvar = None
+
+        # serialize Klass in this process
+        saved = dumps(Klass)
+
+        # Check the reset problem in a new process
+        mp = get_context("spawn")
+        proc = mp.Process(
+            target=check_unpickle_dyn_class_new_proc, args=(saved,)
+        )
+        proc.start()
+        proc.join(timeout=60)
+        self.assertEqual(proc.exitcode, 0)
+
+    def test_dynamic_class_issue_7356(self):
+        cfunc = numba.njit(issue_7356)
+        self.assertEqual(cfunc(), (100, 100))
+
+
+class DynClass(object):
+    # For testing issue #7356
+    a = None
+
+
+def issue_7356():
+    with numba.objmode(before="intp"):
+        DynClass.a = 100
+        before = DynClass.a
+    with numba.objmode(after="intp"):
+        after = DynClass.a
+    return before, after
+
+
+def check_main_class_reset_on_unpickle():
+    # Load module and get its global dictionary
+    glbs = runpy.run_module(
+        "numba.tests.cloudpickle_main_class",
+        run_name="__main__",
+    )
+    # Get the Klass and check it is from __main__
+    Klass = glbs["Klass"]
+    assert Klass.__module__ == "__main__"
+    assert Klass.classvar != 100
+    saved = dumps(Klass)
+    # mutate
+    Klass.classvar = 100
+    # check
+    _check_dyn_class(Klass, saved)
+
+
+def check_unpickle_dyn_class_new_proc(saved):
+    Klass = loads(saved)
+    assert Klass.classvar != 100
+    # mutate
+    Klass.classvar = 100
+    # check
+    _check_dyn_class(Klass, saved)
+
+
+def _check_dyn_class(Klass, saved):
+    def check():
+        if Klass.classvar != 100:
+            raise AssertionError("Check failed. Klass reset.")
+
+    check()
+    loaded = loads(saved)
+    if loaded is not Klass:
+        raise AssertionError("Expected reuse")
+    # Without the patch, each `loads(saved)` will reset `Klass.classvar`
+    check()
+    loaded = loads(saved)
+    if loaded is not Klass:
+        raise AssertionError("Expected reuse")
+    check()
+
+
+if __name__ == "__main__":
+    unittest.main()
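
Reviewer note (not part of the patch): a minimal sketch of how downstream code would typically use the vendored module after this change, based only on the `ReduceMixin`, `dumps`, and `loads` helpers added above. The class `PackedThing` and its `payload` attribute are hypothetical names for illustration.

# Illustrative only; assumes the vendored numba.cuda.serialize from this diff.
from numba.cuda import serialize

class PackedThing(serialize.ReduceMixin):
    def __init__(self, payload):
        self.payload = payload

    def _reduce_states(self):
        # States are handed back to _rebuild() as keyword arguments.
        return dict(payload=self.payload)

    @classmethod
    def _rebuild(cls, payload):
        return cls(payload)

# Round-trip through the NumbaPickler-based dumps()/loads() helpers.
data = serialize.dumps(PackedThing(42))
restored = serialize.loads(data)
assert restored.payload == 42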