Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion numba_cuda/numba/cuda/codegen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from llvmlite import ir

from numba.core import config, serialize
from numba.core import config
from numba.cuda import serialize
from .cudadrv import devices, driver, nvvm, runtime, nvrtc
from numba.cuda.core.codegen import Codegen, CodeLibrary
from numba.cuda.cudadrv.libs import get_cudalib
Expand Down
4 changes: 2 additions & 2 deletions numba_cuda/numba/cuda/cudadrv/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@


from numba import mviewbuf
from numba.core import serialize, config
from numba.cuda import utils
from numba.core import config
from numba.cuda import utils, serialize
from .error import CudaSupportError, CudaDriverError
from .drvapi import API_PROTOTYPES
from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
Expand Down
4 changes: 2 additions & 2 deletions numba_cuda/numba/cuda/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import weakref
import uuid

from numba.core import compiler, serialize, sigutils, types, typing, config
from numba.cuda import utils
from numba.core import compiler, sigutils, types, typing, config
from numba.cuda import serialize, utils
from numba.cuda.core.caching import Cache, CacheImpl, NullCache
from numba.core.compiler_lock import global_compiler_lock
from numba.core.dispatcher import _DispatcherBase
Expand Down
264 changes: 264 additions & 0 deletions numba_cuda/numba/cuda/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
"""
Serialization support for compiled functions.
"""

import sys
import abc
import io
import copyreg


import pickle
from numba import cloudpickle
from llvmlite import ir


#
# Pickle support
#


def _rebuild_reduction(cls, *args):
"""
Global hook to rebuild a given class from its __reduce__ arguments.
"""
return cls._rebuild(*args)


# Keep unpickled object via `numba_unpickle` alive.
_unpickled_memo = {}


def _numba_unpickle(address, bytedata, hashed):
"""Used by `numba_unpickle` from _helperlib.c

Parameters
----------
address : int
bytedata : bytes
hashed : bytes

Returns
-------
obj : object
unpickled object
"""
key = (address, hashed)
try:
obj = _unpickled_memo[key]
except KeyError:
_unpickled_memo[key] = obj = cloudpickle.loads(bytedata)
return obj


def dumps(obj):
"""Similar to `pickle.dumps()`. Returns the serialized object in bytes."""
pickler = NumbaPickler
with io.BytesIO() as buf:
p = pickler(buf, protocol=4)
p.dump(obj)
pickled = buf.getvalue()

return pickled


def runtime_build_excinfo_struct(static_exc, exc_args):
exc, static_args, locinfo = cloudpickle.loads(static_exc)
real_args = []
exc_args_iter = iter(exc_args)
for arg in static_args:
if isinstance(arg, ir.Value):
real_args.append(next(exc_args_iter))
else:
real_args.append(arg)
return (exc, tuple(real_args), locinfo)


# Alias to pickle.loads to allow `serialize.loads()`
loads = cloudpickle.loads


class _CustomPickled:
"""A wrapper for objects that must be pickled with `NumbaPickler`.

Standard `pickle` will pick up the implementation registered via `copyreg`.
This will spawn a `NumbaPickler` instance to serialize the data.

`NumbaPickler` overrides the handling of this type so as not to spawn a
new pickler for the object when it is already being pickled by a
`NumbaPickler`.
"""

__slots__ = "ctor", "states"

def __init__(self, ctor, states):
self.ctor = ctor
self.states = states

def _reduce(self):
return _CustomPickled._rebuild, (self.ctor, self.states)

@classmethod
def _rebuild(cls, ctor, states):
return cls(ctor, states)


def _unpickle__CustomPickled(serialized):
"""standard unpickling for `_CustomPickled`.

Uses `NumbaPickler` to load.
"""
ctor, states = loads(serialized)
return _CustomPickled(ctor, states)


def _pickle__CustomPickled(cp):
"""standard pickling for `_CustomPickled`.

Uses `NumbaPickler` to dump.
"""
serialized = dumps((cp.ctor, cp.states))
return _unpickle__CustomPickled, (serialized,)


# Register custom pickling for the standard pickler.
copyreg.pickle(_CustomPickled, _pickle__CustomPickled)


def custom_reduce(cls, states):
"""For customizing object serialization in `__reduce__`.

Object states provided here are used as keyword arguments to the
`._rebuild()` class method.

Parameters
----------
states : dict
Dictionary of object states to be serialized.

Returns
-------
result : tuple
This tuple conforms to the return type requirement for `__reduce__`.
"""
return custom_rebuild, (_CustomPickled(cls, states),)


def custom_rebuild(custom_pickled):
"""Customized object deserialization.

This function is referenced internally by `custom_reduce()`.
"""
cls, states = custom_pickled.ctor, custom_pickled.states
return cls._rebuild(**states)


def is_serialiable(obj):
"""Check if *obj* can be serialized.

Parameters
----------
obj : object

Returns
--------
can_serialize : bool
"""
with io.BytesIO() as fout:
pickler = NumbaPickler(fout)
try:
pickler.dump(obj)
except pickle.PicklingError:
return False
else:
return True


def _no_pickle(obj):
raise pickle.PicklingError(f"Pickling of {type(obj)} is unsupported")


def disable_pickling(typ):
"""This is called on a type to disable pickling"""
NumbaPickler.disabled_types.add(typ)
# Return `typ` to allow use as a decorator
return typ


class NumbaPickler(cloudpickle.CloudPickler):
disabled_types = set()
"""A set of types that pickling cannot is disabled.
"""

def reducer_override(self, obj):
# Overridden to disable pickling of certain types
if type(obj) in self.disabled_types:
_no_pickle(obj) # noreturn
return super().reducer_override(obj)


def _custom_reduce__custompickled(cp):
return cp._reduce()


NumbaPickler.dispatch_table[_CustomPickled] = _custom_reduce__custompickled


class ReduceMixin(abc.ABC):
"""A mixin class for objects that should be reduced by the NumbaPickler
instead of the standard pickler.
"""

# Subclass MUST override the below methods

@abc.abstractmethod
def _reduce_states(self):
raise NotImplementedError

@abc.abstractclassmethod
def _rebuild(cls, **kwargs):
raise NotImplementedError

# Subclass can override the below methods

def _reduce_class(self):
return self.__class__

# Private methods

def __reduce__(self):
return custom_reduce(self._reduce_class(), self._reduce_states())


class PickleCallableByPath:
"""Wrap a callable object to be pickled by path to workaround limitation
in pickling due to non-pickleable objects in function non-locals.

Note:
- Do not use this as a decorator.
- Wrapped object must be a global that exist in its parent module and it
can be imported by `from the_module import the_object`.

Usage:

>>> def my_fn(x):
>>> ...
>>> wrapped_fn = PickleCallableByPath(my_fn)
>>> # refer to `wrapped_fn` instead of `my_fn`
"""

def __init__(self, fn):
self._fn = fn

def __call__(self, *args, **kwargs):
return self._fn(*args, **kwargs)

def __reduce__(self):
return type(self)._rebuild, (
self._fn.__module__,
self._fn.__name__,
)

@classmethod
def _rebuild(cls, modname, fn_path):
return cls(getattr(sys.modules[modname], fn_path))
Loading