Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/node/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
/*!
* \file src/node/structural_hash.cc
*/
#include <dmlc/memory_io.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
Expand All @@ -30,6 +31,7 @@
#include <algorithm>
#include <unordered_map>

#include "../support/base64.h"
#include "../support/str_escape.h"
#include "../support/utils.h"

Expand Down Expand Up @@ -363,7 +365,24 @@ bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs,
}
}

TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait);
TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait)
.set_creator([](const std::string& blob) {
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
support::Base64InStream b64strm(&mstrm);
b64strm.InitPosition();
runtime::NDArray temp;
ICHECK(temp.Load(&b64strm));
return RefToObjectPtr::Get(temp);
})
.set_repr_bytes([](const Object* n) -> std::string {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
support::Base64OutStream b64strm(&mstrm);
const auto* ndarray = static_cast<const runtime::NDArray::Container*>(n);
runtime::SaveDLTensor(&b64strm, &ndarray->dl_tensor);
b64strm.Finish();
return blob;
});

struct ArrayNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
Expand Down
43 changes: 33 additions & 10 deletions tests/python/unittest/test_node_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import sys
import pytest
from tvm import te
import numpy as np


def test_const_saveload_json():
Expand Down Expand Up @@ -160,14 +162,35 @@ def test_dict():
assert set(dir(x.__class__)) <= set(dir(x))


def test_ndarray():
dev = tvm.cpu(0)
tvm_arr = tvm.nd.array(np.random.rand(4), device=dev)
tvm_arr2 = tvm.ir.load_json(tvm.ir.save_json(tvm_arr))
tvm.ir.assert_structural_equal(tvm_arr, tvm_arr2)
np.testing.assert_array_equal(tvm_arr.numpy(), tvm_arr2.numpy())


def test_ndarray_dict():
dev = tvm.cpu(0)
m1 = {
"key1": tvm.nd.array(np.random.rand(4), device=dev),
"key2": tvm.nd.array(np.random.rand(4), device=dev),
}
m2 = tvm.ir.load_json(tvm.ir.save_json(m1))
tvm.ir.assert_structural_equal(m1, m2)


def test_alloc_const():
dev = tvm.cpu(0)
dtype = "float32"
shape = (16,)
buf = tvm.tir.decl_buffer(shape, dtype)
data = tvm.nd.array(np.random.rand(*shape).astype(dtype), device=dev)
body = tvm.tir.Evaluate(0)
stmt = tvm.tir.AllocateConst(buf.data, dtype, shape, data, body)
stmt2 = tvm.ir.load_json(tvm.ir.save_json(stmt))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you double-check that you're actually exporting an NDArray into JSON here? i think this might fix it , but it would be great if we could prove it through e.g. non-empty b64ndarray key

Copy link
Member

@tqchen tqchen May 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My read is that it does not go through b64ndarray mechanism, but instead goes through repr_bytes mechanism as @vinx13 commented. The old b64ndarray mechanism should still continue to work. At some time pt perhaps we can move to repr_bytes mechanism for all cases while keeping b64ndarray for one cycle.

The structural equality in the next line should proves the serialization correctness

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a check of alloc_const.data against ref data

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @vinx13! did we ever figure out which mechanism is used to serialize the ndarray though?

Copy link
Member Author

@vinx13 vinx13 May 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will enter this path https://github.com/apache/tvm/blob/main/src/node/serialization.cc#L156 and call the registered set_repr_bytes during indexing, so it won't be saved into b64array section. The node is later saved in https://github.com/apache/tvm/blob/main/src/node/serialization.cc#L276

tvm.ir.assert_structural_equal(stmt, stmt2)


if __name__ == "__main__":
test_string()
test_env_func()
test_make_node()
test_make_smap()
test_const_saveload_json()
test_make_sum()
test_pass_config()
test_dict()
test_infinity_value()
test_minmax_value()
sys.exit(pytest.main([__file__] + sys.argv[1:]))