diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 68412267da6b..cb6ae5a164b9 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -74,6 +74,38 @@ def reduce_ndarray(data):
 
 ForkingPickler.register(nd.NDArray, reduce_ndarray)
 
+if sys.platform == 'darwin' or sys.platform == 'win32':
+    def rebuild_np_ndarray(*args):
+        """Rebuild ndarray from pickled shared memory"""
+        # pylint: disable=no-value-for-parameter
+        return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(*args))
+
+    def reduce_np_ndarray(data):
+        """Reduce ndarray to shared memory handle"""
+        return rebuild_np_ndarray, data._to_shared_mem()
+else:
+    def rebuild_np_ndarray(pid, fd, shape, dtype):
+        """Rebuild ndarray from pickled shared memory"""
+        # pylint: disable=no-value-for-parameter
+        if sys.version_info[0] == 2:
+            fd = multiprocessing.reduction.rebuild_handle(fd)
+        else:
+            fd = fd.detach()
+        return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype))
+
+    def reduce_np_ndarray(data):
+        """Reduce ndarray to shared memory handle"""
+        # keep a local ref before duplicating fd
+        data = data.as_in_context(context.Context('cpu_shared', 0))
+        pid, fd, shape, dtype = data._to_shared_mem()
+        if sys.version_info[0] == 2:
+            fd = multiprocessing.reduction.reduce_handle(fd)
+        else:
+            fd = multiprocessing.reduction.DupFd(fd)
+        return rebuild_np_ndarray, (pid, fd, shape, dtype)
+
+ForkingPickler.register(_mx_np.ndarray, reduce_np_ndarray)
+
 
 class ConnectionWrapper(object):
     """Connection wrapper for multiprocessing that supports sending
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index b6692c009c18..6077f4df13ae 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -1098,6 +1098,18 @@ def test_np_get_dtype():
     assert type(mx_arr.dtype) == type(np_arr.dtype)
 
 
+@use_np
+def test_np_ndarray_pickle():
+    a = np.random.uniform(size=(4, 5))
+    a_copy = a.copy()
+    import pickle
+    with open("np_ndarray_pickle_test_file", 'wb') as f:
+        pickle.dump(a_copy, f)
+    with open("np_ndarray_pickle_test_file", 'rb') as f:
+        a_load = pickle.load(f)
+    assert same(a.asnumpy(), a_load.asnumpy())
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()