From 0a30c6085254440738708155192dcaa3268fc869 Mon Sep 17 00:00:00 2001 From: Zhiyuan Liu Date: Thu, 30 Jul 2020 17:55:16 -0400 Subject: [PATCH] BUG: Solve an issue mentioned in PR #1942 --- Modules/Bridge/NumPy/wrapping/PyBuffer.i.init | 16 ++++++++ .../NumPy/wrapping/test/itkPyBufferTest.py | 39 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/Modules/Bridge/NumPy/wrapping/PyBuffer.i.init b/Modules/Bridge/NumPy/wrapping/PyBuffer.i.init index 4f8b2250cd4..95935a39941 100644 --- a/Modules/Bridge/NumPy/wrapping/PyBuffer.i.init +++ b/Modules/Bridge/NumPy/wrapping/PyBuffer.i.init @@ -17,6 +17,22 @@ try: except ImportError: HAVE_NUMPY = False +try: + import numpy as np + from distributed.protocol import dask_serialize, dask_deserialize + from typing import Dict, List, Tuple +except ImportError: + pass +else: + @dask_serialize.register(NDArrayITKBase) + def serialize(ndarray_itk_base: NDArrayITKBase) -> Tuple[Dict, List[bytes]]: + dumps = dask_serialize.dispatch(np.ndarray) + return dumps(ndarray_itk_base) + + @dask_deserialize.register(NDArrayITKBase) + def deserialize(header: Dict, frames: List[bytes]) -> NDArrayITKBase: + loads = dask_deserialize.dispatch(np.ndarray) + return NDArrayITKBase(loads(header, frames)) def _get_numpy_pixelid(itk_Image_type): """Returns a ITK PixelID given a numpy array.""" diff --git a/Modules/Bridge/NumPy/wrapping/test/itkPyBufferTest.py b/Modules/Bridge/NumPy/wrapping/test/itkPyBufferTest.py index ebb04aa5e98..1df27edd940 100644 --- a/Modules/Bridge/NumPy/wrapping/test/itkPyBufferTest.py +++ b/Modules/Bridge/NumPy/wrapping/test/itkPyBufferTest.py @@ -28,6 +28,45 @@ class TestNumpyITKMemoryviewInterface(unittest.TestCase): def setUp(self): pass + def test_NDArrayITKBase_pickle(self): + """ + Test the serialization of itk.NDArrayITKBase + """ + Dimension = 3 + ScalarImageType = itk.Image[itk.UC, Dimension] + RegionType = itk.ImageRegion[Dimension] + + region = RegionType() + region.SetSize(0, 6); + region.SetSize(1, 6); + region.SetSize(2, 6); + + scalarImage = ScalarImageType.New() + scalarImage.SetRegions(region); + scalarImage.Allocate(True); + scalarImage.SetPixel([0, 0, 0], 5) + scalarImage.SetPixel([0, 0, 1], 3) + scalarImage.SetPixel([5, 5, 5], 8) + ndarray_itk_base = itk.array_view_from_image(scalarImage) + + import pickle + + ## test serialization of itk ndarrary itk base + pickled = pickle.dumps(ndarray_itk_base) + reloaded = pickle.loads(pickled) + equal = (reloaded == ndarray_itk_base).all() + assert equal, 'Different results before and after pickle' + + try: + import dask + from distributed.protocol.serialize import dask_dumps, dask_loads + except ImportError: + pass + else: + header, frames = dask_dumps(ndarray_itk_base) + recon_obj = dask_loads(header, frames) + equal = (recon_obj == ndarray_itk_base).all() + assert equal, 'Different results before and after pickle' def test_NumPyBridge_itkScalarImage(self): "Try to convert all pixel types to NumPy array view"