Skip to content

Commit e7f1224

Browse files
authored
Fix json serialization for NDArray (#11303)
When `NDArray` is being stored as `ObjectRef`, the serializer won't trigger the right path for storage. Under the new serialization mode, we need to be able to leverage the `repr_bytes` mechanism to save `NDArray`. This change is backward compatible -- ndarray saved in previous format will continue to work. And fixes the problem of serialization when `NDArray` is involved as part of `ObjectRef`. In the future, we can consider consolidate the `NDArray` save into the `repr_bytes` and remove the specialization as we evolve to newer versions
1 parent bd029cb commit e7f1224

File tree

2 files changed

+55
-11
lines changed

2 files changed

+55
-11
lines changed

src/node/structural_hash.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
/*!
2020
* \file src/node/structural_hash.cc
2121
*/
22+
#include <dmlc/memory_io.h>
2223
#include <tvm/node/functor.h>
2324
#include <tvm/node/node.h>
2425
#include <tvm/node/reflection.h>
@@ -30,6 +31,7 @@
3031
#include <algorithm>
3132
#include <unordered_map>
3233

34+
#include "../support/base64.h"
3335
#include "../support/str_escape.h"
3436
#include "../support/utils.h"
3537

@@ -363,7 +365,24 @@ bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs,
363365
}
364366
}
365367

366-
TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait);
368+
TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait)
369+
.set_creator([](const std::string& blob) {
370+
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
371+
support::Base64InStream b64strm(&mstrm);
372+
b64strm.InitPosition();
373+
runtime::NDArray temp;
374+
ICHECK(temp.Load(&b64strm));
375+
return RefToObjectPtr::Get(temp);
376+
})
377+
.set_repr_bytes([](const Object* n) -> std::string {
378+
std::string blob;
379+
dmlc::MemoryStringStream mstrm(&blob);
380+
support::Base64OutStream b64strm(&mstrm);
381+
const auto* ndarray = static_cast<const runtime::NDArray::Container*>(n);
382+
runtime::SaveDLTensor(&b64strm, &ndarray->dl_tensor);
383+
b64strm.Finish();
384+
return blob;
385+
});
367386

368387
struct ArrayNodeTrait {
369388
static constexpr const std::nullptr_t VisitAttrs = nullptr;

tests/python/unittest/test_node_reflection.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import tvm
18+
import sys
1819
import pytest
1920
from tvm import te
21+
import numpy as np
2022

2123

2224
def test_const_saveload_json():
@@ -160,14 +162,37 @@ def test_dict():
160162
assert set(dir(x.__class__)) <= set(dir(x))
161163

162164

165+
def test_ndarray():
166+
dev = tvm.cpu(0)
167+
tvm_arr = tvm.nd.array(np.random.rand(4), device=dev)
168+
tvm_arr2 = tvm.ir.load_json(tvm.ir.save_json(tvm_arr))
169+
tvm.ir.assert_structural_equal(tvm_arr, tvm_arr2)
170+
np.testing.assert_array_equal(tvm_arr.numpy(), tvm_arr2.numpy())
171+
172+
173+
def test_ndarray_dict():
174+
dev = tvm.cpu(0)
175+
m1 = {
176+
"key1": tvm.nd.array(np.random.rand(4), device=dev),
177+
"key2": tvm.nd.array(np.random.rand(4), device=dev),
178+
}
179+
m2 = tvm.ir.load_json(tvm.ir.save_json(m1))
180+
tvm.ir.assert_structural_equal(m1, m2)
181+
182+
183+
def test_alloc_const():
184+
dev = tvm.cpu(0)
185+
dtype = "float32"
186+
shape = (16,)
187+
buf = tvm.tir.decl_buffer(shape, dtype)
188+
np_data = np.random.rand(*shape).astype(dtype)
189+
data = tvm.nd.array(np_data, device=dev)
190+
body = tvm.tir.Evaluate(0)
191+
alloc_const = tvm.tir.AllocateConst(buf.data, dtype, shape, data, body)
192+
alloc_const2 = tvm.ir.load_json(tvm.ir.save_json(alloc_const))
193+
tvm.ir.assert_structural_equal(alloc_const, alloc_const2)
194+
np.testing.assert_array_equal(np_data, alloc_const2.data.numpy())
195+
196+
163197
if __name__ == "__main__":
164-
test_string()
165-
test_env_func()
166-
test_make_node()
167-
test_make_smap()
168-
test_const_saveload_json()
169-
test_make_sum()
170-
test_pass_config()
171-
test_dict()
172-
test_infinity_value()
173-
test_minmax_value()
198+
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)