diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index e35f0f4a351a..eac1944016df 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1974,7 +1974,8 @@ int MXNDArrayLoad(const char* fname, << "Failed to read 32 bits from file."; } - if (magic == 0x04034b50 || magic == 0x504b0304) { // zip file format; assumed to be npz + if (magic == 0x04034b50 || magic == 0x504b0304 || + magic == 0x06054b50 || magic == 0x504b0506) { // zip file format; assumed to be npz auto[data, names] = npz::load_arrays(fname); ret->ret_handles.resize(data.size()); for (size_t i = 0; i < data.size(); ++i) { diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 9f2c67b7db29..c6221c5c9eae 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1421,3 +1421,8 @@ def test_mixed_array_types_share_memory(): mx_pinned_array = mx_array.as_in_ctx(mx.cpu_pinned(0)) assert not _np.may_share_memory(np_array, mx_pinned_array) assert not _np.shares_memory(np_array, mx_pinned_array) + +@use_np +def test_save_load_empty(tmp_path): + mx.npx.savez(str(tmp_path / 'params.npz')) + mx.npx.load(str(tmp_path / 'params.npz'))