diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 4d82f1e38b5e..e97e5f41bfc2 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -19,6 +19,7 @@ /*! * \file src/node/structural_hash.cc */ +#include #include #include #include @@ -30,6 +31,7 @@ #include #include +#include "../support/base64.h" #include "../support/str_escape.h" #include "../support/utils.h" @@ -363,7 +365,24 @@ bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, } } -TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait); +TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait) + .set_creator([](const std::string& blob) { + dmlc::MemoryStringStream mstrm(const_cast(&blob)); + support::Base64InStream b64strm(&mstrm); + b64strm.InitPosition(); + runtime::NDArray temp; + ICHECK(temp.Load(&b64strm)); + return RefToObjectPtr::Get(temp); + }) + .set_repr_bytes([](const Object* n) -> std::string { + std::string blob; + dmlc::MemoryStringStream mstrm(&blob); + support::Base64OutStream b64strm(&mstrm); + const auto* ndarray = static_cast(n); + runtime::SaveDLTensor(&b64strm, &ndarray->dl_tensor); + b64strm.Finish(); + return blob; + }); struct ArrayNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index c1298b56f7fb..bb300607cfd6 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. import tvm +import sys import pytest from tvm import te +import numpy as np def test_const_saveload_json(): @@ -160,14 +162,37 @@ def test_dict(): assert set(dir(x.__class__)) <= set(dir(x)) +def test_ndarray(): + dev = tvm.cpu(0) + tvm_arr = tvm.nd.array(np.random.rand(4), device=dev) + tvm_arr2 = tvm.ir.load_json(tvm.ir.save_json(tvm_arr)) + tvm.ir.assert_structural_equal(tvm_arr, tvm_arr2) + np.testing.assert_array_equal(tvm_arr.numpy(), tvm_arr2.numpy()) + + +def test_ndarray_dict(): + dev = tvm.cpu(0) + m1 = { + "key1": tvm.nd.array(np.random.rand(4), device=dev), + "key2": tvm.nd.array(np.random.rand(4), device=dev), + } + m2 = tvm.ir.load_json(tvm.ir.save_json(m1)) + tvm.ir.assert_structural_equal(m1, m2) + + +def test_alloc_const(): + dev = tvm.cpu(0) + dtype = "float32" + shape = (16,) + buf = tvm.tir.decl_buffer(shape, dtype) + np_data = np.random.rand(*shape).astype(dtype) + data = tvm.nd.array(np_data, device=dev) + body = tvm.tir.Evaluate(0) + alloc_const = tvm.tir.AllocateConst(buf.data, dtype, shape, data, body) + alloc_const2 = tvm.ir.load_json(tvm.ir.save_json(alloc_const)) + tvm.ir.assert_structural_equal(alloc_const, alloc_const2) + np.testing.assert_array_equal(np_data, alloc_const2.data.numpy()) + + if __name__ == "__main__": - test_string() - test_env_func() - test_make_node() - test_make_smap() - test_const_saveload_json() - test_make_sum() - test_pass_config() - test_dict() - test_infinity_value() - test_minmax_value() + sys.exit(pytest.main([__file__] + sys.argv[1:]))