From 4d54768ba1132109d1b7d59f19b908651c42e8c7 Mon Sep 17 00:00:00 2001 From: reminisce Date: Sat, 25 May 2019 00:55:27 -0700 Subject: [PATCH] Restore save/load ndarray to 1.4.1 --- src/ndarray/ndarray.cc | 11 +++++++---- tests/python/unittest/test_ndarray.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 81cf8448455c..9474d0ce40c2 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1581,6 +1581,9 @@ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8; static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9; void NDArray::Save(dmlc::Stream *strm) const { + // TODO(junwu): Support this after NumPy operators are merged + CHECK(!Imperative::Get()->is_np_comp()) + << "Saving ndarray within the scope of np_shape is not supported."; // write magic number to mark this version // for storage type strm->Write(NDARRAY_V2_MAGIC); @@ -1698,6 +1701,9 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) { } bool NDArray::Load(dmlc::Stream *strm) { + // TODO(junwu): Support this after NumPy operators are merged + CHECK(!Imperative::Get()->is_np_comp()) + << "Loading ndarray within the scope of np_shape is not supported."; uint32_t magic; if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false; if (magic != NDARRAY_V2_MAGIC) { @@ -1718,10 +1724,7 @@ bool NDArray::Load(dmlc::Stream *strm) { // load shape mxnet::TShape shape; if (!shape.Load(strm)) return false; - if (!Imperative::Get()->is_np_comp()) { - common::ConvertToNumpyShape(&shape); - } - if (mxnet::op::shape_is_none(shape)) { + if (shape.ndim() == 0) { *this = NDArray(); return true; } diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index df505436fa0c..8998b215d704 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1701,6 +1701,20 @@ def test_zero_from_numpy(): assert False +@with_seed() +def test_save_load_zero_size_ndarrays(): + shapes = [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)] + array_list = [np.random.randint(0, 10, size=shape) for shape in shapes] + array_list = [mx.nd.array(arr) for arr in array_list] + with TemporaryDirectory() as work_dir: + fname = os.path.join(work_dir, 'dataset') + mx.nd.save(fname, array_list) + array_list_loaded = mx.nd.load(fname) + assert len(array_list) == len(array_list_loaded) + for a1, a2 in zip(array_list, array_list_loaded): + assert np.array_equal(a1.asnumpy(), a2.asnumpy()) + + if __name__ == '__main__': import nose nose.runmodule()