diff --git a/dill/_dill.py b/dill/_dill.py index 55e2ba14..3b7cdb19 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -1189,6 +1189,39 @@ def _create_namedtuple(name, fieldnames, modulename, defaults=None): t = collections.namedtuple(name, fieldnames, defaults=defaults, module=modulename) return t +def _create_capsule(pointer, name, context, destructor): + attr_found = False + try: + # based on https://github.com/python/cpython/blob/f4095e53ab708d95e019c909d5928502775ba68f/Objects/capsule.c#L209-L231 + if PY3: + uname = name.decode('utf8') + else: + uname = name + for i in range(1, uname.count('.')+1): + names = uname.rsplit('.', i) + try: + module = __import__(names[0]) + except: + pass + obj = module + for attr in names[1:]: + obj = getattr(obj, attr) + capsule = obj + attr_found = True + break + except: + pass + + if attr_found: + if _PyCapsule_IsValid(capsule, name): + return capsule + raise UnpicklingError("%s object exists at %s but a PyCapsule object was expected." % (type(capsule), name)) + else: + warnings.warn('Creating a new PyCapsule %s for a C data structure that may not be present in memory. Segmentation faults or other memory errors are possible.' % (name,), UnpicklingWarning) + capsule = _PyCapsule_New(pointer, name, destructor) + _PyCapsule_SetContext(capsule, context) + return capsule + def _getattr(objclass, name, repr_str): # hack to grab the reference directly try: #XXX: works only for __builtin__ ? @@ -2177,6 +2210,52 @@ def save_function(pickler, obj): log.info("# F2") return +if HAS_CTYPES and hasattr(ctypes, 'pythonapi'): + _PyCapsule_New = ctypes.pythonapi.PyCapsule_New + _PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p) + _PyCapsule_New.restype = ctypes.py_object + _PyCapsule_GetPointer = ctypes.pythonapi.PyCapsule_GetPointer + _PyCapsule_GetPointer.argtypes = (ctypes.py_object, ctypes.c_char_p) + _PyCapsule_GetPointer.restype = ctypes.c_void_p + _PyCapsule_GetDestructor = ctypes.pythonapi.PyCapsule_GetDestructor + _PyCapsule_GetDestructor.argtypes = (ctypes.py_object,) + _PyCapsule_GetDestructor.restype = ctypes.c_void_p + _PyCapsule_GetContext = ctypes.pythonapi.PyCapsule_GetContext + _PyCapsule_GetContext.argtypes = (ctypes.py_object,) + _PyCapsule_GetContext.restype = ctypes.c_void_p + _PyCapsule_GetName = ctypes.pythonapi.PyCapsule_GetName + _PyCapsule_GetName.argtypes = (ctypes.py_object,) + _PyCapsule_GetName.restype = ctypes.c_char_p + _PyCapsule_IsValid = ctypes.pythonapi.PyCapsule_IsValid + _PyCapsule_IsValid.argtypes = (ctypes.py_object, ctypes.c_char_p) + _PyCapsule_IsValid.restype = ctypes.c_bool + _PyCapsule_SetContext = ctypes.pythonapi.PyCapsule_SetContext + _PyCapsule_SetContext.argtypes = (ctypes.py_object, ctypes.c_void_p) + _PyCapsule_SetDestructor = ctypes.pythonapi.PyCapsule_SetDestructor + _PyCapsule_SetDestructor.argtypes = (ctypes.py_object, ctypes.c_void_p) + _PyCapsule_SetName = ctypes.pythonapi.PyCapsule_SetName + _PyCapsule_SetName.argtypes = (ctypes.py_object, ctypes.c_char_p) + _PyCapsule_SetPointer = ctypes.pythonapi.PyCapsule_SetPointer + _PyCapsule_SetPointer.argtypes = (ctypes.py_object, ctypes.c_void_p) + _testcapsule = _PyCapsule_New( + ctypes.cast(_PyCapsule_New, ctypes.c_void_p), + ctypes.create_string_buffer(b'dill._dill._testcapsule'), + None + ) + PyCapsuleType = type(_testcapsule) + @register(PyCapsuleType) + def save_capsule(pickler, obj): + log.info("Cap: %s", obj) + name = _PyCapsule_GetName(obj) + warnings.warn('Pickling a PyCapsule (%s) does not pickle any C data structures and could cause segmentation faults or other memory errors when unpickling.' % (name,), PicklingWarning) + pointer = _PyCapsule_GetPointer(obj, name) + context = _PyCapsule_GetContext(obj) + destructor = _PyCapsule_GetDestructor(obj) + pickler.save_reduce(_create_capsule, (pointer, name, context, destructor), obj=obj) + log.info("# Cap") +else: + _testcapsule = None + # quick sanity checking def pickles(obj,exact=False,safe=False,**kwds): """ diff --git a/dill/_objects.py b/dill/_objects.py index 8b1cb65c..31823636 100644 --- a/dill/_objects.py +++ b/dill/_objects.py @@ -548,6 +548,11 @@ class _Struct(ctypes.Structure): else: x['BufferType'] = buffer('') +from dill._dill import _testcapsule +if _testcapsule is not None: + x['PyCapsuleType'] = _testcapsule +del _testcapsule + # -- cleanup ---------------------------------------------------------------- a.update(d) # registered also succeed if sys.platform[:3] == 'win': diff --git a/dill/_shims.py b/dill/_shims.py index 6bda5136..6e170437 100644 --- a/dill/_shims.py +++ b/dill/_shims.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # # Author: Mike McKerns (mmckerns @caltech and @uqfoundation) -# Author: Anirudh Vegesana (avegesan@stanford.edu) +# Author: Anirudh Vegesana (avegesan@cs.stanford.edu) # Copyright (c) 2021-2022 The Uncertainty Quantification Foundation. # License: 3-clause BSD. The full license text is available at: # - https://github.com/uqfoundation/dill/blob/master/LICENSE diff --git a/tests/test_dictviews.py b/tests/test_dictviews.py index 3bbc5d62..213e7ab2 100644 --- a/tests/test_dictviews.py +++ b/tests/test_dictviews.py @@ -1,8 +1,8 @@ #!/usr/bin/env python # # Author: Mike McKerns (mmckerns @caltech and @uqfoundation) -# Copyright (c) 2008-2016 California Institute of Technology. -# Copyright (c) 2016-2021 The Uncertainty Quantification Foundation. +# Author: Anirudh Vegesana (avegesan@cs.stanford.edu) +# Copyright (c) 2021-2022 The Uncertainty Quantification Foundation. # License: 3-clause BSD. The full license text is available at: # - https://github.com/uqfoundation/dill/blob/master/LICENSE diff --git a/tests/test_objects.py b/tests/test_objects.py index 985041be..c83060c3 100644 --- a/tests/test_objects.py +++ b/tests/test_objects.py @@ -57,6 +57,5 @@ def test_objects(): #pickles(member, exact=True) pickles(member, exact=False) - if __name__ == '__main__': test_objects() diff --git a/tests/test_pycapsule.py b/tests/test_pycapsule.py new file mode 100644 index 00000000..6e115ffd --- /dev/null +++ b/tests/test_pycapsule.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# +# Author: Mike McKerns (mmckerns @caltech and @uqfoundation) +# Author: Anirudh Vegesana (avegesan@cs.stanford.edu) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE +""" +test pickling a PyCapsule object +""" + +import dill +import warnings + +test_pycapsule = None + +if dill._dill._testcapsule is not None: + import ctypes + def test_pycapsule(): + name = ctypes.create_string_buffer(b'dill._testcapsule') + capsule = dill._dill._PyCapsule_New( + ctypes.cast(dill._dill._PyCapsule_New, ctypes.c_void_p), + name, + None + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + dill.copy(capsule) + dill._testcapsule = capsule + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + dill.copy(capsule) + dill._testcapsule = None + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", dill.PicklingWarning) + dill.copy(capsule) + except dill.UnpicklingError: + pass + else: + raise AssertionError("Expected a different error") + +if __name__ == '__main__': + if test_pycapsule is not None: + test_pycapsule()