diff --git a/Modules/Bridge/NumPy/wrapping/PyBuffer.i.init b/Modules/Bridge/NumPy/wrapping/PyBuffer.i.init index 892267faa54..545e02e8d98 100644 --- a/Modules/Bridge/NumPy/wrapping/PyBuffer.i.init +++ b/Modules/Bridge/NumPy/wrapping/PyBuffer.i.init @@ -15,6 +15,16 @@ try: if obj is None: return self.itk_base = getattr(obj, 'itk_base', None) + def __reduce_ex__(self, protocol): + np_copy = np.array(self, copy=True) + return type(self)._reconstruct, (np_copy,), None + + @classmethod + def _reconstruct(cls, obj): + with memoryview(obj) as m: + obj = m.obj + if type(obj) is np.ndarray: + return cls(obj, itk_base=None) except ImportError: HAVE_NUMPY = False diff --git a/Modules/Bridge/NumPy/wrapping/test/itkPyBufferTest.py b/Modules/Bridge/NumPy/wrapping/test/itkPyBufferTest.py index ebb04aa5e98..bee7f6f86bf 100644 --- a/Modules/Bridge/NumPy/wrapping/test/itkPyBufferTest.py +++ b/Modules/Bridge/NumPy/wrapping/test/itkPyBufferTest.py @@ -29,6 +29,35 @@ class TestNumpyITKMemoryviewInterface(unittest.TestCase): def setUp(self): pass + def test_NDArrayITKBase_pickle(self): + """ + Test the serialization of itk.NDArrayITKBase + """ + Dimension = 3 + ScalarImageType = itk.Image[itk.UC, Dimension] + RegionType = itk.ImageRegion[Dimension] + + region = RegionType() + region.SetSize(0, 6); + region.SetSize(1, 6); + region.SetSize(2, 6); + + scalarImage = ScalarImageType.New() + scalarImage.SetRegions(region); + scalarImage.Allocate(True); + scalarImage.SetPixel([0, 0, 0], 5) + scalarImage.SetPixel([0, 0, 1], 3) + scalarImage.SetPixel([5, 5, 5], 8) + ndarray_itk_base = itk.array_view_from_image(scalarImage) + + import pickle + + ## test serialization of itk ndarrary itk base + pickled = pickle.dumps(ndarray_itk_base) + reloaded = pickle.loads(pickled) + equal = (reloaded == ndarray_itk_base).all() + assert equal, 'Different results before and after pickle' + def test_NumPyBridge_itkScalarImage(self): "Try to convert all pixel types to NumPy array view"