From e8a20fbb15fb1e8f8cb89c28ccdbc264b661fde6 Mon Sep 17 00:00:00 2001 From: reminisce Date: Sat, 1 Jun 2019 11:50:46 -0700 Subject: [PATCH] Enable serializing/deserializing ndarrays in np_shape semantics (#15090) * Loosen the constraint on serializing/deserializing ndarrays within the scope of np_shape * Support save/load dense ndarrays in np_shape semantics --- src/ndarray/ndarray.cc | 47 ++++++++++++++++++++------- tests/python/unittest/test_ndarray.py | 35 +++++++++++++------- 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 16c579fefa32..af9e7a7170d3 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1200,6 +1200,7 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op << "from.shape = " << from.shape() << " to.shape=" << to.shape(); CHECK(!mxnet::op::shape_is_none(from.shape())) << "source operands have undefined shape"; + if (from.shape().Size() == 0U) return; // important: callback must always capture by value const Context from_ctx = from.ctx(); const int a = from_ctx.dev_mask(); @@ -1580,13 +1581,20 @@ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8; /* magic number for ndarray version 2, with storage type */ static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9; +// magic number for ndarray version 3, with np shape semantics. +// The ndarray must be saved and loaded within np shape semantics.
+static const uint32_t NDARRAY_V3_MAGIC = 0xF993faca; + void NDArray::Save(dmlc::Stream *strm) const { - // TODO(junwu): Support this after NumPy operators are merged - CHECK(!Imperative::Get()->is_np_shape()) - << "Saving ndarray within the scope of np_shape is not supported."; - // write magic number to mark this version - // for storage type - strm->Write(NDARRAY_V2_MAGIC); + if (Imperative::Get()->is_np_shape()) { + CHECK_EQ(storage_type(), kDefaultStorage) + << "only allow serializing ndarray of default storage type in np shape semantics"; + strm->Write(NDARRAY_V3_MAGIC); + } else { + // write magic number to mark this version + // for storage type + strm->Write(NDARRAY_V2_MAGIC); + } // save storage type int32_t stype = storage_type(); @@ -1701,18 +1709,30 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) { } bool NDArray::Load(dmlc::Stream *strm) { - // TODO(junwu): Support this after NumPy operators are merged - CHECK(!Imperative::Get()->is_np_shape()) - << "Loading ndarray within the scope of np_shape is not supported."; uint32_t magic; if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false; - if (magic != NDARRAY_V2_MAGIC) { + if (magic == NDARRAY_V3_MAGIC) { + CHECK(Imperative::Get()->is_np_shape()) + << "ndarray was saved in np shape semantics, must be loaded in the same semantics." + " Please turn on np shape semantics in Python using `with np_shape(True)`" + " or decorator `use_np_shape` to scope the code of loading the ndarray."; + } else { + CHECK(!Imperative::Get()->is_np_shape()) + << "ndarray was not saved in np shape semantics, but being loaded in np shape semantics." 
+ " Please turn off np shape semantics in Python using `with np_shape(False)`" + " to scope of the code of loading the ndarray."; + } + if (magic != NDARRAY_V2_MAGIC && magic != NDARRAY_V3_MAGIC) { return LegacyLoad(strm, magic); } // load storage type int32_t stype; if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false; + if (Imperative::Get()->is_np_shape()) { + CHECK_EQ(stype, kDefaultStorage) + << "only allow deserializing ndarray of default storage type in np shape semantics"; + } const int32_t nad = num_aux_data(static_cast(stype)); // load storage shape @@ -1724,7 +1744,12 @@ bool NDArray::Load(dmlc::Stream *strm) { // load shape mxnet::TShape shape; if (!shape.Load(strm)) return false; - if (shape.ndim() == 0) { + if (Imperative::Get()->is_np_shape()) { + if (!shape_is_known(shape)) { + *this = NDArray(); + return true; + } + } else if (shape.ndim() == 0) { *this = NDArray(); return true; } diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 8b2a270a34a2..e5315900c725 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1702,17 +1702,30 @@ def test_zero_from_numpy(): @with_seed() -def test_save_load_zero_size_ndarrays(): - shapes = [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)] - array_list = [np.random.randint(0, 10, size=shape) for shape in shapes] - array_list = [mx.nd.array(arr) for arr in array_list] - with TemporaryDirectory() as work_dir: - fname = os.path.join(work_dir, 'dataset') - mx.nd.save(fname, array_list) - array_list_loaded = mx.nd.load(fname) - assert len(array_list) == len(array_list_loaded) - for a1, a2 in zip(array_list, array_list_loaded): - assert np.array_equal(a1.asnumpy(), a2.asnumpy()) +def test_save_load_scalar_zero_size_ndarrays(): + def check_save_load(save_is_np_shape, load_is_np_shape, shapes, save_throw_exception, load_throw_exception): + with mx.np_shape(save_is_np_shape): + array_list = [np.random.randint(0, 
10, size=shape) for shape in shapes] + array_list = [mx.nd.array(arr) for arr in array_list] + with TemporaryDirectory() as work_dir: + fname = os.path.join(work_dir, 'dataset') + if save_throw_exception: + assert_exception(mx.nd.save, mx.MXNetError, fname, array_list) + else: + mx.nd.save(fname, array_list) + with mx.np_shape(load_is_np_shape): + if load_throw_exception: + assert_exception(mx.nd.load, mx.MXNetError, fname) + else: + array_list_loaded = mx.nd.load(fname) + assert len(array_list) == len(array_list_loaded) + for a1, a2 in zip(array_list, array_list_loaded): + assert np.array_equal(a1.asnumpy(), a2.asnumpy()) + + check_save_load(False, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False) + check_save_load(True, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True) + check_save_load(False, True, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True) + check_save_load(True, True, [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False) if __name__ == '__main__':