Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Support save/load dense ndarrays in np_shape semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed May 31, 2019
1 parent 189ff8f commit 610606b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
51 changes: 32 additions & 19 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,7 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
<< "from.shape = " << from.shape() << " to.shape=" << to.shape();
CHECK(!mxnet::op::shape_is_none(from.shape()))
<< "source operands have undefined shape";
if (from.shape().Size() == 0U) return;
// important: callback must always capture by value
const Context from_ctx = from.ctx();
const int a = from_ctx.dev_mask();
Expand Down Expand Up @@ -1580,19 +1581,20 @@ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;
/* magic number for ndarray version 2, with storage type */
static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;

// magic number for ndarray version 3, with np shape semantics.
// The ndarray must be saved and loaded within np shape semantics.
static const uint32_t NDARRAY_V3_MAGIC = 0xF993faca;

void NDArray::Save(dmlc::Stream *strm) const {
// TODO(junwu): Support this after NumPy operators are merged
if (Imperative::Get()->is_np_shape()) {
CHECK_EQ(storage_type(), kDefaultStorage)
<< "only allow serializing ndarray of default storage type within the scope of np_shape";
CHECK_NE(shape_.Size(), 0U)
<< "serializing zero-size ndarray within the scope of np_shape is not supported";
CHECK_NE(shape_.ndim(), 0)
<< "serializing scalar ndarray within the scope of np_shape is not supported";
<< "only allow serializing ndarray of default storage type in np shape semantics";
strm->Write(NDARRAY_V3_MAGIC);
} else {
// write magic number to mark this version
// for storage type
strm->Write(NDARRAY_V2_MAGIC);
}
// write magic number to mark this version
// for storage type
strm->Write(NDARRAY_V2_MAGIC);

// save storage type
int32_t stype = storage_type();
Expand Down Expand Up @@ -1709,13 +1711,28 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
bool NDArray::Load(dmlc::Stream *strm) {
uint32_t magic;
if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
if (magic != NDARRAY_V2_MAGIC) {
if (magic == NDARRAY_V3_MAGIC) {
CHECK(Imperative::Get()->is_np_shape())
<< "ndarray was saved in np shape semantics, must be loaded in the same semantics."
" Please turn on np shape semantics in Python using `with np_shape(True)`"
" or decorator `use_np_shape` to scope the code of loading the ndarray.";
} else {
CHECK(!Imperative::Get()->is_np_shape())
<< "ndarray was not saved in np shape semantics, but being loaded in np shape semantics."
" Please turn off np shape semantics in Python using `with np_shape(False)`"
" to scope of the code of loading the ndarray.";
}
if (magic != NDARRAY_V2_MAGIC && magic != NDARRAY_V3_MAGIC) {
return LegacyLoad(strm, magic);
}

// load storage type
int32_t stype;
if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;
if (Imperative::Get()->is_np_shape()) {
CHECK_EQ(stype, kDefaultStorage)
<< "only allow deserializing ndarray of default storage type in np shape semantics";
}
const int32_t nad = num_aux_data(static_cast<NDArrayStorageType>(stype));

// load storage shape
Expand All @@ -1727,16 +1744,12 @@ bool NDArray::Load(dmlc::Stream *strm) {
// load shape
mxnet::TShape shape;
if (!shape.Load(strm)) return false;
// TODO(junwu): Support this after NumPy operators are merged
if (Imperative::Get()->is_np_shape()) {
CHECK_EQ(stype, kDefaultStorage)
<< "only allow deserializing ndarray of default storage type within the scope of np_shape";
CHECK_NE(shape.Size(), 0U)
<< "deserializing zero-size ndarray within the scope of np_shape is not supported";
CHECK_NE(shape.ndim(), 0)
<< "deserializing scalar ndarray within the scope of np_shape is not supported";
}
if (shape.ndim() == 0) {
if (!shape_is_known(shape)) {
*this = NDArray();
return true;
}
} else if (shape.ndim() == 0) {
*this = NDArray(); return true;
}

Expand Down
30 changes: 17 additions & 13 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1702,26 +1702,30 @@ def test_zero_from_numpy():


@with_seed()
def test_save_load_zero_size_ndarrays():
def check_save_load(is_np_shape, shapes, throw_exception):
with mx.np_shape(is_np_shape):
def test_save_load_scalar_zero_size_ndarrays():
def check_save_load(save_is_np_shape, load_is_np_shape, shapes, save_throw_exception, load_throw_exception):
with mx.np_shape(save_is_np_shape):
array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
array_list = [mx.nd.array(arr) for arr in array_list]
with TemporaryDirectory() as work_dir:
fname = os.path.join(work_dir, 'dataset')
if throw_exception:
if save_throw_exception:
assert_exception(mx.nd.save, mx.MXNetError, fname, array_list)
else:
mx.nd.save(fname, array_list)
array_list_loaded = mx.nd.load(fname)
assert len(array_list) == len(array_list_loaded)
for a1, a2 in zip(array_list, array_list_loaded):
assert np.array_equal(a1.asnumpy(), a2.asnumpy())

check_save_load(False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False) # legacy mode
check_save_load(True, [(2, 1), (3, 5)], False) # np_shape semantics, no zero-size, should succeed
check_save_load(True, [(2, 1), (3, 0)], True) # np_shape semantics, zero-size, should fail
check_save_load(True, [(2, 1), ()], True) # np_shape semantics, scalar tensor, should fail
with mx.np_shape(load_is_np_shape):
if load_throw_exception:
assert_exception(mx.nd.load, mx.MXNetError, fname)
else:
array_list_loaded = mx.nd.load(fname)
assert len(array_list) == len(array_list_loaded)
for a1, a2 in zip(array_list, array_list_loaded):
assert np.array_equal(a1.asnumpy(), a2.asnumpy())

check_save_load(False, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False)
check_save_load(True, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True)
check_save_load(False, True, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True)
check_save_load(True, True, [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False)


if __name__ == '__main__':
Expand Down

0 comments on commit 610606b

Please sign in to comment.