Skip to content

Commit d582b7e

Browse files
authored
[CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple (#13671)
This PR add struct equal/hash and json serialization support for shape tuple. Testcases added.
1 parent 8551a5c commit d582b7e

File tree

4 files changed

+70
-2
lines changed

4 files changed

+70
-2
lines changed

src/node/structural_hash.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,50 @@ TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
484484
return ::tvm::runtime::make_object<ArrayNode>();
485485
});
486486

487+
struct ShapeTupleObjTrait {
488+
static constexpr const std::nullptr_t VisitAttrs = nullptr;
489+
490+
static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) {
491+
hash_reduce(self->size);
492+
for (size_t i = 0; i < self->size; ++i) {
493+
hash_reduce(self->data[i]);
494+
}
495+
}
496+
497+
static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs,
498+
SEqualReducer equal) {
499+
if (lhs->size != rhs->size) return false;
500+
for (size_t i = 0; i < lhs->size; ++i) {
501+
if (!equal(lhs->data[i], rhs->data[i])) return false;
502+
}
503+
return true;
504+
}
505+
};
506+
507+
TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait)
508+
.set_creator([](const std::string& blob) {
509+
// Store shape tuple in blob to avoid large integer overflow in JSON.
510+
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
511+
support::Base64InStream b64strm(&mstrm);
512+
b64strm.InitPosition();
513+
uint64_t size;
514+
b64strm.Read<uint64_t>(&size);
515+
std::vector<int64_t> data(size);
516+
b64strm.ReadArray(data.data(), size);
517+
ShapeTuple shape(data);
518+
return RefToObjectPtr::Get(shape);
519+
})
520+
.set_repr_bytes([](const Object* n) -> std::string {
521+
std::string blob;
522+
dmlc::MemoryStringStream mstrm(&blob);
523+
support::Base64OutStream b64strm(&mstrm);
524+
const auto* shape = static_cast<const runtime::ShapeTupleObj*>(n);
525+
b64strm.Write<uint64_t>(shape->size);
526+
b64strm.WriteArray(shape->data, shape->size);
527+
b64strm.Finish();
528+
return blob;
529+
});
530+
487531
struct MapNodeTrait {
488532
static constexpr const std::nullptr_t VisitAttrs = nullptr;
489533

src/support/base64.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,10 @@ class Base64InStream : public dmlc::Stream {
115115
}
116116
/*! \brief whether current position is end of a base64 stream */
117117
bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); }
118+
119+
using dmlc::Stream::Read;
118120
// override read function.
119-
virtual size_t Read(void* ptr, size_t size) {
121+
size_t Read(void* ptr, size_t size) final {
120122
using base64::DecodeTable;
121123
if (size == 0) return 0;
122124
// use tlen to record left size
@@ -224,7 +226,10 @@ class Base64InStream : public dmlc::Stream {
224226
class Base64OutStream : public dmlc::Stream {
225227
public:
226228
explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {}
227-
virtual void Write(const void* ptr, size_t size) {
229+
230+
using dmlc::Stream::Write;
231+
232+
void Write(const void* ptr, size_t size) final {
228233
using base64::EncodeTable;
229234
size_t tlen = size;
230235
const unsigned char* cptr = static_cast<const unsigned char*>(ptr);

tests/python/unittest/test_container_structural_equal.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,20 @@ def test_array_structural_equal_to_self(contents):
107107
assert get_first_mismatch_ensure_symmetry(a, b) is None
108108

109109

110+
@pytest.mark.parametrize(
111+
"contents",
112+
[
113+
[],
114+
[1],
115+
[1, 2, 3],
116+
],
117+
)
118+
def test_shape_tuple_structural_equal_to_self(contents):
119+
a = tvm.runtime.ShapeTuple(list(contents))
120+
b = tvm.runtime.ShapeTuple(list(contents))
121+
assert get_first_mismatch_ensure_symmetry(a, b) is None
122+
123+
110124
@pytest.mark.parametrize(
111125
"a, b, expected_a_path, expected_b_path",
112126
[

tests/python/unittest/test_runtime_container.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def test_shape_tuple():
9090
# ShapleTuple vs. ShapeTuple
9191
assert stuple == _container.ShapeTuple(shape)
9292

93+
# test pickle
94+
z = pickle.loads(pickle.dumps(stuple))
95+
assert isinstance(z, tvm.runtime.ShapeTuple)
96+
assert stuple == z
97+
9398

9499
if __name__ == "__main__":
95100
test_string()

0 commit comments

Comments
 (0)