@@ -221,6 +221,16 @@ class MapObj : public Object {
221221 uint64_t size_;
222222 /* ! \brief number of slots */
223223 uint64_t slots_;
224+ /* !
225+ * \brief Small layout tag mask
226+ * \note The most significant bit is used to indicate the small map layout.
227+ */
228+ static constexpr uint64_t kSmallTagMask = static_cast <uint64_t >(1 ) << 63 ;
229+ /* !
230+ * \brief Check if the map is a small map
231+ * \return True if the map is a small map
232+ */
233+ bool IsSmallMap () const { return (slots_ & kSmallTagMask ) != 0ull ; }
224234 /* !
225235 * \brief Optional data deleter when data is allocated separately
226236 * and its deletion is not managed by MapObj::deleter_.
@@ -242,6 +252,13 @@ class SmallMapObj : public MapObj,
242252 using MapObj::iterator;
243253 using MapObj::KVType;
244254
255+ // Return the number of usable slots for Small layout (mask off tag).
256+ /* !
257+ * \brief Return the number of usable slots for Small layout (mask off tag).
258+ * \return The number of usable slots
259+ */
260+ uint64_t NumSlots () const { return slots_ & ~kSmallTagMask ; }
261+
245262 ~SmallMapObj () {
246263 KVType* begin = static_cast <KVType*>(data_);
247264 for (uint64_t index = 0 ; index < size_; ++index) {
@@ -310,6 +327,11 @@ class SmallMapObj : public MapObj,
310327 void erase (const iterator& position) { Erase (position.index ); }
311328
312329 private:
330+ /* !
331+ * \brief Set the number of slots and attach tags bit.
332+ * \param n The number of slots
333+ */
334+ void SetSlotsAndSmallLayoutTag (uint64_t n) { slots_ = (n & ~kSmallTagMask ) | kSmallTagMask ; }
313335 /* !
314336 * \brief Remove a position in SmallMapObj
315337 * \param index The position to be removed
@@ -344,7 +366,7 @@ class SmallMapObj : public MapObj,
344366 ObjectPtr<SmallMapObj> p = make_inplace_array_object<SmallMapObj, KVType>(n);
345367 p->data_ = p->AddressOf (0 );
346368 p->size_ = 0 ;
347- p->slots_ = n ;
369+ p->SetSlotsAndSmallLayoutTag (n) ;
348370 return p;
349371 }
350372 /* !
@@ -386,15 +408,15 @@ class SmallMapObj : public MapObj,
386408 itr->second = kv.second ;
387409 return ;
388410 }
389- if (map_node->size_ < map_node->slots_ ) {
411+ if (map_node->size_ < map_node->NumSlots () ) {
390412 KVType* ptr = static_cast <KVType*>(map_node->data_ ) + map_node->size_ ;
391413 new (ptr) KVType (std::move (kv));
392414 ++map_node->size_ ;
393415 return ;
394416 }
395- uint64_t next_size = std::max (map_node->slots_ * 2 , uint64_t (kInitSize ));
417+ uint64_t next_size = std::max (map_node->NumSlots () * 2 , uint64_t (kInitSize ));
396418 next_size = std::min (next_size, uint64_t (kMaxSize ));
397- TVM_FFI_ICHECK_GT (next_size, map_node->slots_ );
419+ TVM_FFI_ICHECK_GT (next_size, map_node->NumSlots () );
398420 ObjectPtr<Object> new_map = CreateFromRange (next_size, map_node->begin (), map_node->end ());
399421 InsertMaybeReHash (std::move (kv), &new_map);
400422 *map = std::move (new_map);
@@ -525,6 +547,12 @@ class DenseMapObj : public MapObj {
525547 public:
526548 using MapObj::iterator;
527549
550+ /* !
551+ * \brief Return the number of usable slots for Dense layout (MSB clear => identity).
552+ * \return The number of usable slots
553+ */
554+ uint64_t NumSlots () const { return slots_; }
555+
528556 /* !
529557 * \brief Destroy the DenseMapObj
530558 */
@@ -558,7 +586,7 @@ class DenseMapObj : public MapObj {
558586 */
559587 void erase (const iterator& position) {
560588 uint64_t index = position.index ;
561- if (position.self != nullptr && index <= this ->slots_ ) {
589+ if (position.self != nullptr && index <= this ->NumSlots () ) {
562590 Erase (ListNode (index, this ));
563591 }
564592 }
@@ -817,7 +845,7 @@ class DenseMapObj : public MapObj {
817845 }
818846 /* ! \brief Clear the container to empty, release all entries and memory acquired */
819847 void Reset () {
820- uint64_t n_blocks = CalcNumBlocks (this ->slots_ );
848+ uint64_t n_blocks = CalcNumBlocks (this ->NumSlots () );
821849 for (uint64_t bi = 0 ; bi < n_blocks; ++bi) {
822850 uint8_t * meta_ptr = GetBlock (bi)->bytes ;
823851 ItemType* data_ptr = reinterpret_cast <ItemType*>(GetBlock (bi)->bytes + kBlockCap );
@@ -852,6 +880,8 @@ class DenseMapObj : public MapObj {
852880 */
853881 static ObjectPtr<DenseMapObj> Empty (uint32_t fib_shift, uint64_t n_slots) {
854882 TVM_FFI_ICHECK_GT (n_slots, uint64_t (SmallMapObj::kMaxSize ));
883+ // Ensure even slot count (power-of-two expected by callers; this guard
884+ // makes the method robust if a non-even value slips through).
855885 ObjectPtr<DenseMapObj> p = make_object<DenseMapObj>();
856886 uint64_t n_blocks = CalcNumBlocks (n_slots);
857887 Block* block = new Block[n_blocks];
@@ -860,7 +890,7 @@ class DenseMapObj : public MapObj {
860890 // in another shared-lib that may have different malloc/free behavior
861891 // it will still be safe.
862892 p->data_deleter_ = BlockDeleter;
863- p->slots_ = n_slots;
893+ p->SetSlotsAndDenseLayoutTag ( n_slots) ;
864894 p->size_ = 0 ;
865895 p->fib_shift_ = fib_shift;
866896 p->iter_list_head_ = kInvalidIndex ;
@@ -877,13 +907,13 @@ class DenseMapObj : public MapObj {
877907 */
878908 static ObjectPtr<DenseMapObj> CopyFrom (DenseMapObj* from) {
879909 ObjectPtr<DenseMapObj> p = make_object<DenseMapObj>();
880- uint64_t n_blocks = CalcNumBlocks (from->slots_ );
910+ uint64_t n_blocks = CalcNumBlocks (from->NumSlots () );
881911 p->data_ = new Block[n_blocks];
882912 // assign block deleter so even if we take re-alloc data
883913 // in another shared-lib that may have different malloc/free behavior
884914 // it will still be safe.
885915 p->data_deleter_ = BlockDeleter;
886- p->slots_ = from->slots_ ;
916+ p->SetSlotsAndDenseLayoutTag ( from->NumSlots ()) ;
887917 p->size_ = from->size_ ;
888918 p->fib_shift_ = from->fib_shift_ ;
889919 p->iter_list_head_ = from->iter_list_head_ ;
@@ -919,9 +949,9 @@ class DenseMapObj : public MapObj {
919949 map_node->IterListPushBack (iter);
920950 return ;
921951 }
922- TVM_FFI_ICHECK_GT ( map_node->slots_ , uint64_t (SmallMapObj:: kMaxSize ));
952+ TVM_FFI_ICHECK (! map_node->IsSmallMap ( ));
923953 // Otherwise, start rehash
924- ObjectPtr<Object> p = Empty (map_node->fib_shift_ - 1 , map_node->slots_ * 2 );
954+ ObjectPtr<Object> p = Empty (map_node->fib_shift_ - 1 , map_node->NumSlots () * 2 );
925955
926956 // need to insert in the same order as the original map
927957 for (uint64_t index = map_node->iter_list_head_ ; index != kInvalidIndex ;) {
@@ -947,7 +977,7 @@ class DenseMapObj : public MapObj {
947977 * \brief Check whether the hash table is full
948978 * \return A boolean indicating whether hash table is full
949979 */
950- bool IsFull () const { return size_ + 1 > slots_ * kMaxLoadFactor ; }
980+ bool IsFull () const { return size_ + 1 > NumSlots () * kMaxLoadFactor ; }
951981 /* !
952982 * \brief Increment the pointer
953983 * \param index The pointer to be incremented
@@ -1089,7 +1119,7 @@ class DenseMapObj : public MapObj {
10891119 }
10901120 // the probing will go to next position and round back to stay within the
10911121 // correct range of the slots
1092- index = (index + offset) % self->slots_ ;
1122+ index = (index + offset) % self->NumSlots () ;
10931123 block = self->GetBlock (index / kBlockCap );
10941124 return true ;
10951125 }
@@ -1110,7 +1140,7 @@ class DenseMapObj : public MapObj {
11101140 for (uint8_t idx = 1 ; idx < kNumJumpDists ; ++idx) {
11111141 // the probing will go to next position and round back to stay within the
11121142 // correct range of the slots
1113- ListNode candidate ((index + NextProbeLocation (idx)) % self->slots_ , self);
1143+ ListNode candidate ((index + NextProbeLocation (idx)) % self->NumSlots () , self);
11141144 if (candidate.IsEmpty ()) {
11151145 *jump = idx;
11161146 *result = candidate;
@@ -1164,14 +1194,23 @@ class DenseMapObj : public MapObj {
11641194 return kNextProbeLocation [index];
11651195 }
11661196 friend class MapObj ;
1197+
1198+ private:
1199+ /* !
1200+ * \brief Set the number of slots and attach tags bit.
1201+ * \param n The number of slots
1202+ */
1203+ void SetSlotsAndDenseLayoutTag (uint64_t n) {
1204+ TVM_FFI_ICHECK (((n & kSmallTagMask ) == 0ull )) << " DenseMap expects MSB clear" ;
1205+ slots_ = n;
1206+ }
11671207};
11681208
11691209#define TVM_FFI_DISPATCH_MAP (base, var, body ) \
11701210 { \
11711211 using TSmall = SmallMapObj*; \
11721212 using TDense = DenseMapObj*; \
1173- uint64_t slots = base->slots_ ; \
1174- if (slots <= SmallMapObj::kMaxSize ) { \
1213+ if (base->IsSmallMap ()) { \
11751214 TSmall var = static_cast <TSmall>(base); \
11761215 body; \
11771216 } else { \
@@ -1184,8 +1223,7 @@ class DenseMapObj : public MapObj {
11841223 { \
11851224 using TSmall = const SmallMapObj*; \
11861225 using TDense = const DenseMapObj*; \
1187- uint64_t slots = base->slots_ ; \
1188- if (slots <= SmallMapObj::kMaxSize ) { \
1226+ if (base->IsSmallMap ()) { \
11891227 TSmall var = static_cast <TSmall>(base); \
11901228 body; \
11911229 } else { \
@@ -1249,7 +1287,7 @@ inline void MapObj::erase(const MapObj::iterator& position) {
12491287inline ObjectPtr<MapObj> MapObj::Empty () { return SmallMapObj::Empty (); }
12501288
12511289inline ObjectPtr<MapObj> MapObj::CopyFrom (MapObj* from) {
1252- if (from->slots_ <= SmallMapObj:: kMaxSize ) {
1290+ if (from->IsSmallMap () ) {
12531291 return SmallMapObj::CopyFrom (static_cast <SmallMapObj*>(from));
12541292 } else {
12551293 return DenseMapObj::CopyFrom (static_cast <DenseMapObj*>(from));
@@ -1288,20 +1326,22 @@ inline ObjectPtr<Object> MapObj::CreateFromRange(IterType first, IterType last)
12881326}
12891327
12901328inline void MapObj::InsertMaybeReHash (KVType&& kv, ObjectPtr<Object>* map) {
1291- constexpr uint64_t kSmallMapMaxSize = SmallMapObj::kMaxSize ;
12921329 MapObj* base = static_cast <MapObj*>(map->get ());
12931330#if TVM_FFI_DEBUG_WITH_ABI_CHANGE
12941331 base->state_marker ++;
12951332#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE
1296- if (base->slots_ < kSmallMapMaxSize ) {
1297- SmallMapObj::InsertMaybeReHash (std::move (kv), map);
1298- } else if (base->slots_ == kSmallMapMaxSize ) {
1299- if (base->size_ < base->slots_ ) {
1333+ if (base->IsSmallMap ()) {
1334+ SmallMapObj* sm = static_cast <SmallMapObj*>(base);
1335+ if (sm->NumSlots () < SmallMapObj::kMaxSize ) {
13001336 SmallMapObj::InsertMaybeReHash (std::move (kv), map);
1301- } else {
1302- ObjectPtr<Object> new_map = MapObj::CreateFromRange (base->begin (), base->end ());
1303- DenseMapObj::InsertMaybeReHash (std::move (kv), &new_map);
1304- *map = std::move (new_map);
1337+ } else if (sm->NumSlots () == SmallMapObj::kMaxSize ) {
1338+ if (base->size_ < sm->NumSlots ()) {
1339+ SmallMapObj::InsertMaybeReHash (std::move (kv), map);
1340+ } else {
1341+ ObjectPtr<Object> new_map = MapObj::CreateFromRange (base->begin (), base->end ());
1342+ DenseMapObj::InsertMaybeReHash (std::move (kv), &new_map);
1343+ *map = std::move (new_map);
1344+ }
13051345 }
13061346 } else {
13071347 DenseMapObj::InsertMaybeReHash (std::move (kv), map);
0 commit comments