Commit ab68e7e
Loosen the constraint on serializing/deserializing ndarrays within the scope of np_shape
reminisce committed May 29, 2019
1 parent 5fc4fc5 commit ab68e7e
Showing 2 changed files with 36 additions and 15 deletions.
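
In short, this commit replaces the blanket "no save/load under np_shape" check with narrower checks: within the np_shape scope, default-storage ndarrays with a known, non-empty, non-scalar shape can now be serialized and deserialized, while non-default (sparse) storage, zero-size, and scalar ndarrays are still rejected. Below is a minimal sketch of the newly allowed path, assuming an MXNet build that includes this commit; the file path is illustrative.

import os
import tempfile

import numpy as np
import mxnet as mx

# Within np_shape scope, dense ndarrays with non-empty, non-scalar shapes
# can now round-trip through mx.nd.save / mx.nd.load.
with mx.np_shape(True):
    arrays = [mx.nd.array(np.random.uniform(size=shape)) for shape in [(2, 3), (4, 1, 5)]]
    fname = os.path.join(tempfile.gettempdir(), 'np_shape_demo.params')  # illustrative file name
    mx.nd.save(fname, arrays)
    loaded = mx.nd.load(fname)
    for a, b in zip(arrays, loaded):
        assert np.array_equal(a.asnumpy(), b.asnumpy())
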
22 changes: 17 additions & 5 deletions src/ndarray/ndarray.cc
@@ -1582,8 +1582,14 @@ static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;
 
 void NDArray::Save(dmlc::Stream *strm) const {
   // TODO(junwu): Support this after NumPy operators are merged
-  CHECK(!Imperative::Get()->is_np_shape())
-      << "Saving ndarray within the scope of np_shape is not supported.";
+  if (Imperative::Get()->is_np_shape()) {
+    CHECK_EQ(storage_type(), kDefaultStorage)
+        << "only allow serializing ndarray of default storage type within the scope of np_shape";
+    CHECK_NE(shape_.Size(), 0U)
+        << "serializing zero-size ndarray within the scope of np_shape is not supported";
+    CHECK_NE(shape_.ndim(), 0)
+        << "serializing scalar ndarray within the scope of np_shape is not supported";
+  }
   // write magic number to mark this version
   // for storage type
   strm->Write(NDARRAY_V2_MAGIC);
@@ -1701,9 +1707,6 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
 }
 
 bool NDArray::Load(dmlc::Stream *strm) {
-  // TODO(junwu): Support this after NumPy operators are merged
-  CHECK(!Imperative::Get()->is_np_shape())
-      << "Loading ndarray within the scope of np_shape is not supported.";
   uint32_t magic;
   if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
   if (magic != NDARRAY_V2_MAGIC) {
@@ -1724,6 +1727,15 @@ bool NDArray::Load(dmlc::Stream *strm) {
   // load shape
   mxnet::TShape shape;
   if (!shape.Load(strm)) return false;
+  // TODO(junwu): Support this after NumPy operators are merged
+  if (Imperative::Get()->is_np_shape()) {
+    CHECK_EQ(stype, kDefaultStorage)
+        << "only allow deserializing ndarray of default storage type within the scope of np_shape";
+    CHECK_NE(shape.Size(), 0U)
+        << "deserializing zero-size ndarray within the scope of np_shape is not supported";
+    CHECK_NE(shape.ndim(), 0)
+        << "deserializing scalar ndarray within the scope of np_shape is not supported";
+  }
   if (shape.ndim() == 0) {
     *this = NDArray(); return true;
   }
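
On the Python side, the new checks surface as MXNetError when save or load is attempted on a still-unsupported ndarray under np_shape; the Load path continues to treat a 0-dim shape as an empty ndarray (see the ndim() == 0 branch above), which is why scalar and zero-size arrays remain excluded for now. A hedged sketch of the rejected path follows, with the error condition paraphrased from the CHECK messages above and an illustrative file path.

import numpy as np
import mxnet as mx

with mx.np_shape(True):
    zero_size = mx.nd.array(np.ones((2, 0)))  # zero-size ndarray is constructible under np_shape
    try:
        mx.nd.save('/tmp/np_shape_zero_size.params', [zero_size])  # illustrative path
    except mx.MXNetError:
        print('zero-size ndarrays still cannot be serialized within the scope of np_shape')
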
29 changes: 19 additions & 10 deletions tests/python/unittest/test_ndarray.py
@@ -1703,16 +1703,25 @@ def test_zero_from_numpy():
 
 @with_seed()
 def test_save_load_zero_size_ndarrays():
-    shapes = [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)]
-    array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
-    array_list = [mx.nd.array(arr) for arr in array_list]
-    with TemporaryDirectory() as work_dir:
-        fname = os.path.join(work_dir, 'dataset')
-        mx.nd.save(fname, array_list)
-        array_list_loaded = mx.nd.load(fname)
-        assert len(array_list) == len(array_list_loaded)
-        for a1, a2 in zip(array_list, array_list_loaded):
-            assert np.array_equal(a1.asnumpy(), a2.asnumpy())
+    def check_save_load(is_np_shape, shapes, throw_exception):
+        with mx.np_shape(is_np_shape):
+            array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
+            array_list = [mx.nd.array(arr) for arr in array_list]
+            with TemporaryDirectory() as work_dir:
+                fname = os.path.join(work_dir, 'dataset')
+                if throw_exception:
+                    assert_exception(mx.nd.save, mx.MXNetError, fname, array_list)
+                else:
+                    mx.nd.save(fname, array_list)
+                    array_list_loaded = mx.nd.load(fname)
+                    assert len(array_list) == len(array_list_loaded)
+                    for a1, a2 in zip(array_list, array_list_loaded):
+                        assert np.array_equal(a1.asnumpy(), a2.asnumpy())
+
+    check_save_load(False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False)  # legacy mode
+    check_save_load(True, [(2, 1), (3, 5)], False)  # np_shape semantics, no zero-size, should succeed
+    check_save_load(True, [(2, 1), (3, 0)], True)  # np_shape semantics, zero-size, should fail
+    check_save_load(True, [(2, 1), ()], True)  # np_shape semantics, scalar tensor, should fail
 
 
 if __name__ == '__main__':
