@@ -60,6 +60,7 @@ class AnyView {
6060 void reset () {
6161 data_.type_index = TypeIndex::kTVMFFINone ;
6262 // invariance: always set the union padding part to 0
63+ data_.zero_padding = 0 ;
6364 data_.v_int64 = 0 ;
6465 }
6566 /* !
@@ -72,6 +73,7 @@ class AnyView {
7273 // default constructors
7374 AnyView () {
7475 data_.type_index = TypeIndex::kTVMFFINone ;
76+ data_.zero_padding = 0 ;
7577 data_.v_int64 = 0 ;
7678 }
7779 ~AnyView () = default ;
@@ -80,6 +82,7 @@ class AnyView {
8082 AnyView& operator =(const AnyView&) = default ;
8183 AnyView (AnyView&& other) : data_(other.data_) {
8284 other.data_ .type_index = TypeIndex::kTVMFFINone ;
85+ other.data_ .zero_padding = 0 ;
8386 other.data_ .v_int64 = 0 ;
8487 }
8588 TVM_FFI_INLINE AnyView& operator =(AnyView&& other) {
@@ -198,22 +201,19 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data,
198201 if (data->type_index == TypeIndex::kTVMFFIRawStr ) {
199202 // convert raw string to owned string object
200203 String temp (data->v_c_str );
201- data->type_index = TypeIndex::kTVMFFIStr ;
202- data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr (std::move (temp));
204+ TypeTraits<String>::MoveToAny (std::move (temp), data);
203205 } else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr ) {
204206 // convert byte array to owned bytes object
205207 Bytes temp (*static_cast <TVMFFIByteArray*>(data->v_ptr ));
206- data->type_index = TypeIndex::kTVMFFIBytes ;
207- data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr (std::move (temp));
208+ TypeTraits<Bytes>::MoveToAny (std::move (temp), data);
208209 } else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef ) {
209210 // convert rvalue ref to owned object
210211 Object** obj_addr = static_cast <Object**>(data->v_ptr );
211212 TVM_FFI_ICHECK (obj_addr[0 ] != nullptr ) << " RValueRef already moved" ;
212213 ObjectRef temp (details::ObjectUnsafe::ObjectPtrFromOwned<Object>(obj_addr[0 ]));
213214 // set the rvalue ref to nullptr to avoid double move
214215 obj_addr[0 ] = nullptr ;
215- data->type_index = temp->type_index ();
216- data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr (std::move (temp));
216+ TypeTraits<ObjectRef>::MoveToAny (std::move (temp), data);
217217 }
218218 }
219219}
@@ -239,6 +239,7 @@ class Any {
239239 details::ObjectUnsafe::DecRefObjectHandle (data_.v_obj );
240240 }
241241 data_.type_index = TVMFFITypeIndex::kTVMFFINone ;
242+ data_.zero_padding = 0 ;
242243 data_.v_int64 = 0 ;
243244 }
244245 /* !
@@ -251,6 +252,7 @@ class Any {
251252 // default constructors
252253 Any () {
253254 data_.type_index = TypeIndex::kTVMFFINone ;
255+ data_.zero_padding = 0 ;
254256 data_.v_int64 = 0 ;
255257 }
256258 ~Any () { this ->reset (); }
@@ -262,6 +264,7 @@ class Any {
262264 }
263265 Any (Any&& other) : data_(other.data_) {
264266 other.data_ .type_index = TypeIndex::kTVMFFINone ;
267+ other.data_ .zero_padding = 0 ;
265268 other.data_ .v_int64 = 0 ;
266269 }
267270 TVM_FFI_INLINE Any& operator =(const Any& other) {
@@ -408,7 +411,8 @@ class Any {
408411 * \return True if the two Any are same type and value, false otherwise.
409412 */
410413 TVM_FFI_INLINE bool same_as (const Any& other) const noexcept {
411- return data_.type_index == other.data_ .type_index && data_.v_int64 == other.data_ .v_int64 ;
414+ return data_.type_index == other.data_ .type_index &&
415+ data_.zero_padding == other.data_ .zero_padding && data_.v_int64 == other.data_ .v_int64 ;
412416 }
413417
414418 /*
@@ -485,6 +489,7 @@ struct AnyUnsafe : public ObjectUnsafe {
485489 TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny (Any&& any) {
486490 TVMFFIAny result = any.data_ ;
487491 any.data_ .type_index = TypeIndex::kTVMFFINone ;
492+ any.data_ .zero_padding = 0 ;
488493 any.data_ .v_int64 = 0 ;
489494 return result;
490495 }
@@ -493,6 +498,7 @@ struct AnyUnsafe : public ObjectUnsafe {
493498 Any any;
494499 any.data_ = data;
495500 data.type_index = TypeIndex::kTVMFFINone ;
501+ data.zero_padding = 0 ;
496502 data.v_int64 = 0 ;
497503 return any;
498504 }
@@ -543,17 +549,24 @@ struct AnyHash {
543549 * \return Hash code of a, string hash for strings and pointer address otherwise.
544550 */
545551 uint64_t operator ()(const Any& src) const {
546- uint64_t val_hash = [&]() -> uint64_t {
547- if (src.data_ .type_index == TypeIndex::kTVMFFIStr ||
548- src.data_ .type_index == TypeIndex::kTVMFFIBytes ) {
549- const details::BytesObjBase* src_str =
550- details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
551- return details::StableHashBytes (src_str->data , src_str->size );
552- } else {
553- return src.data_ .v_uint64 ;
554- }
555- }();
556- return details::StableHashCombine (src.data_ .type_index , val_hash);
552+ if (src.data_ .type_index == TypeIndex::kTVMFFISmallStr ) {
553+ // for small string, we use the same type key hash as normal string
554+ // so heap allocated string and on stack string will have the same hash
555+ return details::StableHashCombine (TypeIndex::kTVMFFIStr ,
556+ details::StableHashSmallStrBytes (&src.data_ ));
557+ } else if (src.data_ .type_index == TypeIndex::kTVMFFISmallBytes ) {
558+ // use byte the same type key as bytes
559+ return details::StableHashCombine (TypeIndex::kTVMFFIBytes ,
560+ details::StableHashSmallStrBytes (&src.data_ ));
561+ } else if (src.data_ .type_index == TypeIndex::kTVMFFIStr ||
562+ src.data_ .type_index == TypeIndex::kTVMFFIBytes ) {
563+ const details::BytesObjBase* src_str =
564+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
565+ return details::StableHashCombine (src.data_ .type_index ,
566+ details::StableHashBytes (src_str->data , src_str->size ));
567+ } else {
568+ return details::StableHashCombine (src.data_ .type_index , src.data_ .v_uint64 );
569+ }
557570 }
558571};
559572
@@ -566,19 +579,60 @@ struct AnyEqual {
566579 * \return String equality if both are strings, pointer address equality otherwise.
567580 */
568581 bool operator ()(const Any& lhs, const Any& rhs) const {
569- if (lhs.data_ .type_index != rhs.data_ .type_index ) return false ;
570- // byte equivalence
571- if (lhs.data_ .v_int64 == rhs.data_ .v_int64 ) return true ;
572- // specialy handle string hash
573- if (lhs.data_ .type_index == TypeIndex::kTVMFFIStr ||
574- lhs.data_ .type_index == TypeIndex::kTVMFFIBytes ) {
575- const details::BytesObjBase* lhs_str =
576- details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
577- const details::BytesObjBase* rhs_str =
578- details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
579- return Bytes::memequal (lhs_str->data , rhs_str->data , lhs_str->size , rhs_str->size );
582+ // header with type index
583+ const int64_t * lhs_as_int64 = reinterpret_cast <const int64_t *>(&lhs.data_ );
584+ const int64_t * rhs_as_int64 = reinterpret_cast <const int64_t *>(&rhs.data_ );
585+ static_assert (sizeof (TVMFFIAny) == 16 );
586+ // fast path, check byte equality
587+ if (lhs_as_int64[0 ] == rhs_as_int64[0 ] && lhs_as_int64[1 ] == rhs_as_int64[1 ]) {
588+ return true ;
589+ }
590+ // common false case type index match, in this case we only need to pay attention to string
591+ // equality
592+ if (lhs.data_ .type_index == rhs.data_ .type_index ) {
593+ // specialy handle string hash
594+ if (lhs.data_ .type_index == TypeIndex::kTVMFFIStr ||
595+ lhs.data_ .type_index == TypeIndex::kTVMFFIBytes ) {
596+ const details::BytesObjBase* lhs_str =
597+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
598+ const details::BytesObjBase* rhs_str =
599+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
600+ return Bytes::memequal (lhs_str->data , rhs_str->data , lhs_str->size , rhs_str->size );
601+ }
602+ return false ;
603+ } else {
604+ // type_index mismatch, if index is not string, return false
605+ if (lhs.data_ .type_index != kTVMFFIStr && lhs.data_ .type_index != kTVMFFISmallStr &&
606+ lhs.data_ .type_index != kTVMFFISmallBytes && lhs.data_ .type_index != kTVMFFIBytes ) {
607+ return false ;
608+ }
609+ // small string and normal string comparison
610+ if (lhs.data_ .type_index == kTVMFFIStr && rhs.data_ .type_index == kTVMFFISmallStr ) {
611+ const details::BytesObjBase* lhs_str =
612+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
613+ return Bytes::memequal (lhs_str->data , rhs.data_ .v_bytes , lhs_str->size ,
614+ rhs.data_ .small_str_len );
615+ }
616+ if (lhs.data_ .type_index == kTVMFFISmallStr && rhs.data_ .type_index == kTVMFFIStr ) {
617+ const details::BytesObjBase* rhs_str =
618+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
619+ return Bytes::memequal (lhs.data_ .v_bytes , rhs_str->data , lhs.data_ .small_str_len ,
620+ rhs_str->size );
621+ }
622+ if (lhs.data_ .type_index == kTVMFFIBytes && rhs.data_ .type_index == kTVMFFISmallBytes ) {
623+ const details::BytesObjBase* lhs_bytes =
624+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
625+ return Bytes::memequal (lhs_bytes->data , rhs.data_ .v_bytes , lhs_bytes->size ,
626+ rhs.data_ .small_str_len );
627+ }
628+ if (lhs.data_ .type_index == kTVMFFISmallBytes && rhs.data_ .type_index == kTVMFFIBytes ) {
629+ const details::BytesObjBase* rhs_bytes =
630+ details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
631+ return Bytes::memequal (lhs.data_ .v_bytes , rhs_bytes->data , lhs.data_ .small_str_len ,
632+ rhs_bytes->size );
633+ }
634+ return false ;
580635 }
581- return false ;
582636 }
583637};
584638
0 commit comments