diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index df06b5c19641..af9e7a7170d3 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1200,6 +1200,7 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op << "from.shape = " << from.shape() << " to.shape=" << to.shape(); CHECK(!mxnet::op::shape_is_none(from.shape())) << "source operands have undefined shape"; + if (from.shape().Size() == 0U) return; // important: callback must always capture by value const Context from_ctx = from.ctx(); const int a = from_ctx.dev_mask(); @@ -1580,19 +1581,20 @@ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8; /* magic number for ndarray version 2, with storage type */ static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9; +// magic number for ndarray version 3, with np shape semantics. +// The ndarray must be saved and loaded within np shape semantics. +static const uint32_t NDARRAY_V3_MAGIC = 0xF993faca; + void NDArray::Save(dmlc::Stream *strm) const { - // TODO(junwu): Support this after NumPy operators are merged if (Imperative::Get()->is_np_shape()) { CHECK_EQ(storage_type(), kDefaultStorage) - << "only allow serializing ndarray of default storage type within the scope of np_shape"; - CHECK_NE(shape_.Size(), 0U) - << "serializing zero-size ndarray within the scope of np_shape is not supported"; - CHECK_NE(shape_.ndim(), 0) - << "serializing scalar ndarray within the scope of np_shape is not supported"; + << "only allow serializing ndarray of default storage type in np shape semantics"; + strm->Write(NDARRAY_V3_MAGIC); + } else { + // write magic number to mark this version + // for storage type + strm->Write(NDARRAY_V2_MAGIC); } - // write magic number to mark this version - // for storage type - strm->Write(NDARRAY_V2_MAGIC); // save storage type int32_t stype = storage_type(); @@ -1709,13 +1711,28 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) { bool NDArray::Load(dmlc::Stream *strm) { uint32_t magic; if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false; - if (magic != NDARRAY_V2_MAGIC) { + if (magic == NDARRAY_V3_MAGIC) { + CHECK(Imperative::Get()->is_np_shape()) + << "ndarray was saved in np shape semantics, must be loaded in the same semantics." + " Please turn on np shape semantics in Python using `with np_shape(True)`" + " or decorator `use_np_shape` to scope the code of loading the ndarray."; + } else { + CHECK(!Imperative::Get()->is_np_shape()) + << "ndarray was not saved in np shape semantics, but being loaded in np shape semantics." + " Please turn off np shape semantics in Python using `with np_shape(False)`" + " to scope of the code of loading the ndarray."; + } + if (magic != NDARRAY_V2_MAGIC && magic != NDARRAY_V3_MAGIC) { return LegacyLoad(strm, magic); } // load storage type int32_t stype; if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false; + if (Imperative::Get()->is_np_shape()) { + CHECK_EQ(stype, kDefaultStorage) + << "only allow deserializing ndarray of default storage type in np shape semantics"; + } const int32_t nad = num_aux_data(static_cast(stype)); // load storage shape @@ -1727,16 +1744,12 @@ bool NDArray::Load(dmlc::Stream *strm) { // load shape mxnet::TShape shape; if (!shape.Load(strm)) return false; - // TODO(junwu): Support this after NumPy operators are merged if (Imperative::Get()->is_np_shape()) { - CHECK_EQ(stype, kDefaultStorage) - << "only allow deserializing ndarray of default storage type within the scope of np_shape"; - CHECK_NE(shape.Size(), 0U) - << "deserializing zero-size ndarray within the scope of np_shape is not supported"; - CHECK_NE(shape.ndim(), 0) - << "deserializing scalar ndarray within the scope of np_shape is not supported"; - } - if (shape.ndim() == 0) { + if (!shape_is_known(shape)) { + *this = NDArray(); + return true; + } + } else if (shape.ndim() == 0) { *this = NDArray(); return true; } diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index fd23548c64e7..e5315900c725 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1702,26 +1702,30 @@ def test_zero_from_numpy(): @with_seed() -def test_save_load_zero_size_ndarrays(): - def check_save_load(is_np_shape, shapes, throw_exception): - with mx.np_shape(is_np_shape): +def test_save_load_scalar_zero_size_ndarrays(): + def check_save_load(save_is_np_shape, load_is_np_shape, shapes, save_throw_exception, load_throw_exception): + with mx.np_shape(save_is_np_shape): array_list = [np.random.randint(0, 10, size=shape) for shape in shapes] array_list = [mx.nd.array(arr) for arr in array_list] with TemporaryDirectory() as work_dir: fname = os.path.join(work_dir, 'dataset') - if throw_exception: + if save_throw_exception: assert_exception(mx.nd.save, mx.MXNetError, fname, array_list) else: mx.nd.save(fname, array_list) - array_list_loaded = mx.nd.load(fname) - assert len(array_list) == len(array_list_loaded) - for a1, a2 in zip(array_list, array_list_loaded): - assert np.array_equal(a1.asnumpy(), a2.asnumpy()) - - check_save_load(False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False) # legacy mode - check_save_load(True, [(2, 1), (3, 5)], False) # np_shape semantics, no zero-size, should succeed - check_save_load(True, [(2, 1), (3, 0)], True) # np_shape semantics, zero-size, should fail - check_save_load(True, [(2, 1), ()], True) # np_shape semantics, scalar tensor, should fail + with mx.np_shape(load_is_np_shape): + if load_throw_exception: + assert_exception(mx.nd.load, mx.MXNetError, fname) + else: + array_list_loaded = mx.nd.load(fname) + assert len(array_list) == len(array_list_loaded) + for a1, a2 in zip(array_list, array_list_loaded): + assert np.array_equal(a1.asnumpy(), a2.asnumpy()) + + check_save_load(False, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False) + check_save_load(True, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True) + check_save_load(False, True, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True) + check_save_load(True, True, [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False) if __name__ == '__main__':