Enable serializing/deserializing ndarrays in np_shape semantics #15090

Merged

47 changes: 36 additions & 11 deletions src/ndarray/ndarray.cc
@@ -1200,6 +1200,7 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
      << "from.shape = " << from.shape() << " to.shape=" << to.shape();
  CHECK(!mxnet::op::shape_is_none(from.shape()))
      << "source operands have undefined shape";
  if (from.shape().Size() == 0U) return;
  // important: callback must always capture by value
  const Context from_ctx = from.ctx();
  const int a = from_ctx.dev_mask();
@@ -1580,13 +1581,20 @@ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;
/* magic number for ndarray version 2, with storage type */
static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;

// magic number for ndarray version 3, with np shape semantics.
// The ndarray must be saved and loaded within np shape semantics.
static const uint32_t NDARRAY_V3_MAGIC = 0xF993faca;
Contributor
Do we have any tests for handling legacy storage types?

void NDArray::Save(dmlc::Stream *strm) const {
  // TODO(junwu): Support this after NumPy operators are merged
  CHECK(!Imperative::Get()->is_np_shape())
      << "Saving ndarray within the scope of np_shape is not supported.";
  // write magic number to mark this version
  // for storage type
  strm->Write(NDARRAY_V2_MAGIC);
  if (Imperative::Get()->is_np_shape()) {
    CHECK_EQ(storage_type(), kDefaultStorage)
        << "only allow serializing ndarray of default storage type in np shape semantics";
    strm->Write(NDARRAY_V3_MAGIC);
  } else {
    // write magic number to mark this version
    // for storage type
    strm->Write(NDARRAY_V2_MAGIC);
  }

  // save storage type
  int32_t stype = storage_type();
@@ -1701,18 +1709,30 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
}

bool NDArray::Load(dmlc::Stream *strm) {
  // TODO(junwu): Support this after NumPy operators are merged
  CHECK(!Imperative::Get()->is_np_shape())
      << "Loading ndarray within the scope of np_shape is not supported.";
  uint32_t magic;
  if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
  if (magic != NDARRAY_V2_MAGIC) {
  if (magic == NDARRAY_V3_MAGIC) {
    CHECK(Imperative::Get()->is_np_shape())
        << "ndarray was saved in np shape semantics, must be loaded in the same semantics."
           " Please turn on np shape semantics in Python using `with np_shape(True)`"
           " or decorator `use_np_shape` to scope the code of loading the ndarray.";
  } else {
    CHECK(!Imperative::Get()->is_np_shape())
        << "ndarray was not saved in np shape semantics, but being loaded in np shape semantics."
           " Please turn off np shape semantics in Python using `with np_shape(False)`"
           " to scope the code of loading the ndarray.";
  }
  if (magic != NDARRAY_V2_MAGIC && magic != NDARRAY_V3_MAGIC) {
    return LegacyLoad(strm, magic);
  }

  // load storage type
  int32_t stype;
  if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;
  if (Imperative::Get()->is_np_shape()) {
    CHECK_EQ(stype, kDefaultStorage)
        << "only allow deserializing ndarray of default storage type in np shape semantics";
  }
  const int32_t nad = num_aux_data(static_cast<NDArrayStorageType>(stype));

  // load storage shape
@@ -1724,7 +1744,12 @@ bool NDArray::Load(dmlc::Stream *strm) {
  // load shape
  mxnet::TShape shape;
  if (!shape.Load(strm)) return false;
  if (shape.ndim() == 0) {
  if (Imperative::Get()->is_np_shape()) {
    if (!shape_is_known(shape)) {
      *this = NDArray();
      return true;
    }
  } else if (shape.ndim() == 0) {
    *this = NDArray(); return true;
  }

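Taken together, the changes above mean that ndarrays saved inside np shape semantics are written with the new NDARRAY_V3_MAGIC header and can only be loaded inside np shape semantics, while files written with NDARRAY_V2_MAGIC can only be loaded with np shape semantics turned off. A minimal user-level sketch of that behavior (not part of the diff; the file name is illustrative, the `mx.np_shape`, `mx.nd.save`/`mx.nd.load`, and `mx.MXNetError` names are the public APIs already used in the test below):

```python
import os
import numpy as np
import mxnet as mx

fname = os.path.join(os.getcwd(), 'np_shape_arrays')  # illustrative file name

# Saving inside np shape semantics: zero-size shapes are legal ndarray shapes,
# and the file is written with the new NDARRAY_V3_MAGIC header.
with mx.np_shape(True):
    arrays = [mx.nd.array(np.random.randint(0, 10, size=s)) for s in [(0,), (2, 0, 3)]]
    mx.nd.save(fname, arrays)

# Loading inside np shape semantics round-trips the arrays, zero-size shapes included.
with mx.np_shape(True):
    loaded = mx.nd.load(fname)
    assert [a.shape for a in arrays] == [b.shape for b in loaded]

# Loading the same file outside np shape semantics hits the CHECK in NDArray::Load
# and surfaces as an MXNetError in Python.
try:
    mx.nd.load(fname)
except mx.MXNetError as err:
    print('expected failure:', err)
```
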
35 changes: 24 additions & 11 deletions tests/python/unittest/test_ndarray.py
@@ -1702,17 +1702,30 @@ def test_zero_from_numpy():


@with_seed()
def test_save_load_zero_size_ndarrays():
    shapes = [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)]
    array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
    array_list = [mx.nd.array(arr) for arr in array_list]
    with TemporaryDirectory() as work_dir:
        fname = os.path.join(work_dir, 'dataset')
        mx.nd.save(fname, array_list)
        array_list_loaded = mx.nd.load(fname)
        assert len(array_list) == len(array_list_loaded)
        for a1, a2 in zip(array_list, array_list_loaded):
            assert np.array_equal(a1.asnumpy(), a2.asnumpy())
def test_save_load_scalar_zero_size_ndarrays():
    def check_save_load(save_is_np_shape, load_is_np_shape, shapes, save_throw_exception, load_throw_exception):
        with mx.np_shape(save_is_np_shape):
            array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
            array_list = [mx.nd.array(arr) for arr in array_list]
            with TemporaryDirectory() as work_dir:
                fname = os.path.join(work_dir, 'dataset')
                if save_throw_exception:
                    assert_exception(mx.nd.save, mx.MXNetError, fname, array_list)
                else:
                    mx.nd.save(fname, array_list)
                with mx.np_shape(load_is_np_shape):
                    if load_throw_exception:
                        assert_exception(mx.nd.load, mx.MXNetError, fname)
                    else:
                        array_list_loaded = mx.nd.load(fname)
                        assert len(array_list) == len(array_list_loaded)
                        for a1, a2 in zip(array_list, array_list_loaded):
                            assert np.array_equal(a1.asnumpy(), a2.asnumpy())

    check_save_load(False, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False)
    check_save_load(True, False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True)
    check_save_load(False, True, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, True)
    check_save_load(True, True, [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False)


if __name__ == '__main__':
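The four `check_save_load` cases cover the semantics matrix: saving and loading under matching semantics succeeds, mismatched semantics fail at load time with MXNetError, and scalar `()` shapes appear only in the case where np shape semantics is enabled on both sides. A rough sketch of that scalar round-trip (not part of the PR; the file name is illustrative):

```python
import numpy as np
import mxnet as mx

with mx.np_shape(True):
    scalar = mx.nd.array(np.array(3.0))      # 0-d ndarray; shape () needs np shape semantics
    mx.nd.save('scalar_ndarray', [scalar])   # illustrative file name
    restored = mx.nd.load('scalar_ndarray')[0]
    assert restored.shape == ()
    assert np.array_equal(scalar.asnumpy(), restored.asnumpy())
```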