Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Enable serializing/deserializing ndarrays in np_shape semantics (#15090)
Browse files Browse the repository at this point in the history
* Loosen the contraint on serializing/deserializing ndarrays within the scope of np_shape

* Support save/load dense ndarrays in np_shape semantics
  • Loading branch information
reminisce authored and zheng-da committed Jun 1, 2019
1 parent 866ec10 commit e8a20fb
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 22 deletions.
47 changes: 36 additions & 11 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,7 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
<< "from.shape = " << from.shape() << " to.shape=" << to.shape();
CHECK(!mxnet::op::shape_is_none(from.shape()))
<< "source operands have undefined shape";
if (from.shape().Size() == 0U) return;
// important: callback must always capture by value
const Context from_ctx = from.ctx();
const int a = from_ctx.dev_mask();
Expand Down Expand Up @@ -1580,13 +1581,20 @@ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;
/* magic number for ndarray version 2, with storage type */
static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;

// magic number for ndarray version 3, with np shape semantics.
// The ndarray must be saved and loaded within np shape semantics.
static const uint32_t NDARRAY_V3_MAGIC = 0xF993faca;

void NDArray::Save(dmlc::Stream *strm) const {
// TODO(junwu): Support this after NumPy operators are merged
CHECK(!Imperative::Get()->is_np_shape())
<< "Saving ndarray within the scope of np_shape is not supported.";
// write magic number to mark this version
// for storage type
strm->Write(NDARRAY_V2_MAGIC);
if (Imperative::Get()->is_np_shape()) {
CHECK_EQ(storage_type(), kDefaultStorage)
<< "only allow serializing ndarray of default storage type in np shape semantics";
strm->Write(NDARRAY_V3_MAGIC);
} else {
// write magic number to mark this version
// for storage type
strm->Write(NDARRAY_V2_MAGIC);
}

// save storage type
int32_t stype = storage_type();
Expand Down Expand Up @@ -1701,18 +1709,30 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
}

bool NDArray::Load(dmlc::Stream *strm) {
// TODO(junwu): Support this after NumPy operators are merged
CHECK(!Imperative::Get()->is_np_shape())
<< "Loading ndarray within the scope of np_shape is not supported.";
uint32_t magic;
if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
if (magic != NDARRAY_V2_MAGIC) {
if (magic == NDARRAY_V3_MAGIC) {
CHECK(Imperative::Get()->is_np_shape())
<< "ndarray was saved in np shape semantics, must be loaded in the same semantics."
" Please turn on np shape semantics in Python using `with np_shape(True)`"
" or decorator `use_np_shape` to scope the code of loading the ndarray.";
} else {
CHECK(!Imperative::Get()->is_np_shape())
<< "ndarray was not saved in np shape semantics, but being loaded in np shape semantics."
" Please turn off np shape semantics in Python using `with np_shape(False)`"
" to scope of the code of loading the ndarray.";
}
if (magic != NDARRAY_V2_MAGIC && magic != NDARRAY_V3_MAGIC) {
return LegacyLoad(strm, magic);
}

// load storage type
int32_t stype;
if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;
if (Imperative::Get()->is_np_shape()) {
CHECK_EQ(stype, kDefaultStorage)
<< "only allow deserializing ndarray of default storage type in np shape semantics";
}
const int32_t nad = num_aux_data(static_cast<NDArrayStorageType>(stype));

// load storage shape
Expand All @@ -1724,7 +1744,12 @@ bool NDArray::Load(dmlc::Stream *strm) {
// load shape
mxnet::TShape shape;
if (!shape.Load(strm)) return false;
if (shape.ndim() == 0) {
if (Imperative::Get()->is_np_shape()) {
if (!shape_is_known(shape)) {
*this = NDArray();
return true;
}
} else if (shape.ndim() == 0) {
*this = NDArray(); return true;
}

Expand Down
35 changes: 24 additions & 11 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1702,17 +1702,30 @@ def test_zero_from_numpy():


@with_seed()
def test_save_load_zero_size_ndarrays():
shapes = [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)]
array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
array_list = [mx.nd.array(arr) for arr in array_list]
with TemporaryDirectory() as work_dir:
fname = os.path.join(work_dir, 'dataset')
mx.nd.save(fname, array_list)
array_list_loaded = mx.nd.load(fname)
assert len(array_list) == len(array_list_loaded)
for a1, a2 in zip(array_list, array_list_loaded):
assert np.array_equal(a1.asnumpy(), a2.asnumpy())
def test_save_load_scalar_zero_size_ndarrays():
def check_save_load(save_is_np_shape, load_is_np_shape, shapes, save_throw_exception, load_throw_exception):
with mx.np_shape(save_is_np_shape):
array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
array_list = [mx.nd.array(arr) for arr in array_list]
with TemporaryDirectory() as work_dir:
fname = os.path.join(work_dir, 'dataset')
if save_throw_exception:
assert_exception(mx.nd.save, mx.MXNetError, fname, array_list)
else:
mx.nd.save(fname, array_list)
with mx.np_shape(load_is_np_shape):
if load_throw_exception:
assert_exception(mx.nd.load, mx.MXNetError, fname)
else:
array_list_loaded = mx.nd.load(fname)
assert len(array_list) == len(array_list_loaded)
for a1, a2 in zip(array_list, array_list_loaded):
assert np.array_equal(a1.asnumpy(), a2.asnumpy())

check_save_load(False, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False)
check_save_load(True, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True)
check_save_load(False, True, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True)
check_save_load(True, True, [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False)


if __name__ == '__main__':
Expand Down

0 comments on commit e8a20fb

Please sign in to comment.