Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
pickler override for np ndarrays
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Oct 21, 2019
1 parent 746cbc5 commit 6c7e1e7
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
32 changes: 32 additions & 0 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,38 @@ def reduce_ndarray(data):

ForkingPickler.register(nd.NDArray, reduce_ndarray)

if sys.platform == 'darwin' or sys.platform == 'win32':
def rebuild_ndarray(*args):
"""Rebuild ndarray from pickled shared memory"""
# pylint: disable=no-value-for-parameter
return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(*args))

def reduce_ndarray(data):
"""Reduce ndarray to shared memory handle"""
return rebuild_ndarray, data._to_shared_mem()
else:
def rebuild_ndarray(pid, fd, shape, dtype):
"""Rebuild ndarray from pickled shared memory"""
# pylint: disable=no-value-for-parameter
if sys.version_info[0] == 2:
fd = multiprocessing.reduction.rebuild_handle(fd)
else:
fd = fd.detach()
return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype))

def reduce_ndarray(data):
"""Reduce ndarray to shared memory handle"""
# keep a local ref before duplicating fd
data = data.as_in_context(context.Context('cpu_shared', 0))
pid, fd, shape, dtype = data._to_shared_mem()
if sys.version_info[0] == 2:
fd = multiprocessing.reduction.reduce_handle(fd)
else:
fd = multiprocessing.reduction.DupFd(fd)
return rebuild_ndarray, (pid, fd, shape, dtype)

ForkingPickler.register(_mx_np.ndarray, reduce_ndarray)


class ConnectionWrapper(object):
"""Connection wrapper for multiprocessing that supports sending
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,18 @@ def test_np_get_dtype():
assert type(mx_arr.dtype) == type(np_arr.dtype)


@use_np
def test_np_ndarray_pickle():
a = np.random.uniform(size=(4, 5))
a_copy = a.copy()
import pickle
with open("np_ndarray_pickle_test_file", 'wb') as f:
pickle.dump(a_copy, f)
with open("np_ndarray_pickle_test_file", 'rb') as f:
a_load = pickle.load(f)
same(a.asnumpy(), a_load.asnumpy())


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 6c7e1e7

Please sign in to comment.