Skip to content

Commit 03e8a6b

Browse files
committed
[FFI][REFACTOR] Update Map ABI to enable flexible smallMap switch (apache#18200)
This PR updates the Map ABI to use the most significant bit (MSB) of slots_ to indicate a SmallMap. The change opens the door to future adjustments of the boundary at which a map switches between the small and dense layouts.
1 parent 4be1af7 commit 03e8a6b

File tree

1 file changed

+68
-28
lines changed
  • include/tvm/ffi/container

1 file changed

+68
-28
lines changed

include/tvm/ffi/container/map.h

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,16 @@ class MapObj : public Object {
221221
uint64_t size_;
222222
/*! \brief number of slots */
223223
uint64_t slots_;
224+
/*!
225+
* \brief Small layout tag mask
226+
* \note The most significant bit is used to indicate the small map layout.
227+
*/
228+
static constexpr uint64_t kSmallTagMask = static_cast<uint64_t>(1) << 63;
229+
/*!
230+
* \brief Check whether this map uses the small layout, i.e. the tag bit in slots_ is set
231+
* \return True if (slots_ & kSmallTagMask) is non-zero
232+
*/
233+
bool IsSmallMap() const { return (slots_ & kSmallTagMask) != 0ull; }
224234
/*!
225235
* \brief Optional data deleter when data is allocated separately
226236
* and its deletion is not managed by MapObj::deleter_.
@@ -242,6 +252,13 @@ class SmallMapObj : public MapObj,
242252
using MapObj::iterator;
243253
using MapObj::KVType;
244254

255+
// slots_ carries the small-layout tag in its most significant bit, so it must be masked before use.
256+
/*!
257+
* \brief Return the number of usable slots for Small layout (mask off tag).
258+
* \return The number of usable slots
259+
*/
260+
uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; }
261+
245262
~SmallMapObj() {
246263
KVType* begin = static_cast<KVType*>(data_);
247264
for (uint64_t index = 0; index < size_; ++index) {
@@ -310,6 +327,11 @@ class SmallMapObj : public MapObj,
310327
void erase(const iterator& position) { Erase(position.index); }
311328

312329
private:
330+
/*!
331+
* \brief Store the slot count in slots_ with the small-layout tag bit (kSmallTagMask) set.
332+
* \param n The number of slots; any bit of n that overlaps kSmallTagMask is discarded
333+
*/
334+
void SetSlotsAndSmallLayoutTag(uint64_t n) { slots_ = (n & ~kSmallTagMask) | kSmallTagMask; }
313335
/*!
314336
* \brief Remove a position in SmallMapObj
315337
* \param index The position to be removed
@@ -344,7 +366,7 @@ class SmallMapObj : public MapObj,
344366
ObjectPtr<SmallMapObj> p = make_inplace_array_object<SmallMapObj, KVType>(n);
345367
p->data_ = p->AddressOf(0);
346368
p->size_ = 0;
347-
p->slots_ = n;
369+
p->SetSlotsAndSmallLayoutTag(n);
348370
return p;
349371
}
350372
/*!
@@ -386,15 +408,15 @@ class SmallMapObj : public MapObj,
386408
itr->second = kv.second;
387409
return;
388410
}
389-
if (map_node->size_ < map_node->slots_) {
411+
if (map_node->size_ < map_node->NumSlots()) {
390412
KVType* ptr = static_cast<KVType*>(map_node->data_) + map_node->size_;
391413
new (ptr) KVType(std::move(kv));
392414
++map_node->size_;
393415
return;
394416
}
395-
uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize));
417+
uint64_t next_size = std::max(map_node->NumSlots() * 2, uint64_t(kInitSize));
396418
next_size = std::min(next_size, uint64_t(kMaxSize));
397-
TVM_FFI_ICHECK_GT(next_size, map_node->slots_);
419+
TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots());
398420
ObjectPtr<Object> new_map = CreateFromRange(next_size, map_node->begin(), map_node->end());
399421
InsertMaybeReHash(std::move(kv), &new_map);
400422
*map = std::move(new_map);
@@ -525,6 +547,12 @@ class DenseMapObj : public MapObj {
525547
public:
526548
using MapObj::iterator;
527549

550+
/*!
551+
* \brief Return the number of usable slots for Dense layout; slots_ is returned as-is, with no tag masking.
552+
* \return The number of usable slots
553+
*/
554+
uint64_t NumSlots() const { return slots_; }
555+
528556
/*!
529557
* \brief Destroy the DenseMapObj
530558
*/
@@ -558,7 +586,7 @@ class DenseMapObj : public MapObj {
558586
*/
559587
void erase(const iterator& position) {
560588
uint64_t index = position.index;
561-
if (position.self != nullptr && index <= this->slots_) {
589+
if (position.self != nullptr && index < this->NumSlots()) {  // valid slot indices are [0, NumSlots()); anything else is ignored
562590
Erase(ListNode(index, this));
563591
}
564592
}
@@ -817,7 +845,7 @@ class DenseMapObj : public MapObj {
817845
}
818846
/*! \brief Clear the container to empty, release all entries and memory acquired */
819847
void Reset() {
820-
uint64_t n_blocks = CalcNumBlocks(this->slots_);
848+
uint64_t n_blocks = CalcNumBlocks(this->NumSlots());
821849
for (uint64_t bi = 0; bi < n_blocks; ++bi) {
822850
uint8_t* meta_ptr = GetBlock(bi)->bytes;
823851
ItemType* data_ptr = reinterpret_cast<ItemType*>(GetBlock(bi)->bytes + kBlockCap);
@@ -852,6 +880,8 @@ class DenseMapObj : public MapObj {
852880
*/
853881
static ObjectPtr<DenseMapObj> Empty(uint32_t fib_shift, uint64_t n_slots) {
854882
TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize));
883+
// Callers are expected to pass a power-of-two slot count. NOTE(review): the
884+
// "guard" this comment used to describe is not implemented here — confirm or add one.
855885
ObjectPtr<DenseMapObj> p = make_object<DenseMapObj>();
856886
uint64_t n_blocks = CalcNumBlocks(n_slots);
857887
Block* block = new Block[n_blocks];
@@ -860,7 +890,7 @@ class DenseMapObj : public MapObj {
860890
// in another shared-lib that may have different malloc/free behavior
861891
// it will still be safe.
862892
p->data_deleter_ = BlockDeleter;
863-
p->slots_ = n_slots;
893+
p->SetSlotsAndDenseLayoutTag(n_slots);
864894
p->size_ = 0;
865895
p->fib_shift_ = fib_shift;
866896
p->iter_list_head_ = kInvalidIndex;
@@ -877,13 +907,13 @@ class DenseMapObj : public MapObj {
877907
*/
878908
static ObjectPtr<DenseMapObj> CopyFrom(DenseMapObj* from) {
879909
ObjectPtr<DenseMapObj> p = make_object<DenseMapObj>();
880-
uint64_t n_blocks = CalcNumBlocks(from->slots_);
910+
uint64_t n_blocks = CalcNumBlocks(from->NumSlots());
881911
p->data_ = new Block[n_blocks];
882912
// assign block deleter so even if we take re-alloc data
883913
// in another shared-lib that may have different malloc/free behavior
884914
// it will still be safe.
885915
p->data_deleter_ = BlockDeleter;
886-
p->slots_ = from->slots_;
916+
p->SetSlotsAndDenseLayoutTag(from->NumSlots());
887917
p->size_ = from->size_;
888918
p->fib_shift_ = from->fib_shift_;
889919
p->iter_list_head_ = from->iter_list_head_;
@@ -919,9 +949,9 @@ class DenseMapObj : public MapObj {
919949
map_node->IterListPushBack(iter);
920950
return;
921951
}
922-
TVM_FFI_ICHECK_GT(map_node->slots_, uint64_t(SmallMapObj::kMaxSize));
952+
TVM_FFI_ICHECK(!map_node->IsSmallMap());
923953
// Otherwise, start rehash
924-
ObjectPtr<Object> p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2);
954+
ObjectPtr<Object> p = Empty(map_node->fib_shift_ - 1, map_node->NumSlots() * 2);
925955

926956
// need to insert in the same order as the original map
927957
for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) {
@@ -947,7 +977,7 @@ class DenseMapObj : public MapObj {
947977
* \brief Check whether the hash table is full
948978
* \return A boolean indicating whether hash table is full
949979
*/
950-
bool IsFull() const { return size_ + 1 > slots_ * kMaxLoadFactor; }
980+
bool IsFull() const { return size_ + 1 > NumSlots() * kMaxLoadFactor; }
951981
/*!
952982
* \brief Increment the pointer
953983
* \param index The pointer to be incremented
@@ -1089,7 +1119,7 @@ class DenseMapObj : public MapObj {
10891119
}
10901120
// the probing will go to next position and round back to stay within the
10911121
// correct range of the slots
1092-
index = (index + offset) % self->slots_;
1122+
index = (index + offset) % self->NumSlots();
10931123
block = self->GetBlock(index / kBlockCap);
10941124
return true;
10951125
}
@@ -1110,7 +1140,7 @@ class DenseMapObj : public MapObj {
11101140
for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) {
11111141
// the probing will go to next position and round back to stay within the
11121142
// correct range of the slots
1113-
ListNode candidate((index + NextProbeLocation(idx)) % self->slots_, self);
1143+
ListNode candidate((index + NextProbeLocation(idx)) % self->NumSlots(), self);
11141144
if (candidate.IsEmpty()) {
11151145
*jump = idx;
11161146
*result = candidate;
@@ -1164,14 +1194,23 @@ class DenseMapObj : public MapObj {
11641194
return kNextProbeLocation[index];
11651195
}
11661196
friend class MapObj;
1197+
1198+
private:
1199+
/*!
1200+
* \brief Set the number of slots, checking that the small-layout tag bit is not set.
1201+
* \param n The number of slots; must have the most significant bit clear
1202+
*/
1203+
void SetSlotsAndDenseLayoutTag(uint64_t n) {
1204+
TVM_FFI_ICHECK(((n & kSmallTagMask) == 0ull)) << "DenseMap expects MSB clear";
1205+
slots_ = n;
1206+
}
11671207
};
11681208

11691209
#define TVM_FFI_DISPATCH_MAP(base, var, body) \
11701210
{ \
11711211
using TSmall = SmallMapObj*; \
11721212
using TDense = DenseMapObj*; \
1173-
uint64_t slots = base->slots_; \
1174-
if (slots <= SmallMapObj::kMaxSize) { \
1213+
if (base->IsSmallMap()) { \
11751214
TSmall var = static_cast<TSmall>(base); \
11761215
body; \
11771216
} else { \
@@ -1184,8 +1223,7 @@ class DenseMapObj : public MapObj {
11841223
{ \
11851224
using TSmall = const SmallMapObj*; \
11861225
using TDense = const DenseMapObj*; \
1187-
uint64_t slots = base->slots_; \
1188-
if (slots <= SmallMapObj::kMaxSize) { \
1226+
if (base->IsSmallMap()) { \
11891227
TSmall var = static_cast<TSmall>(base); \
11901228
body; \
11911229
} else { \
@@ -1249,7 +1287,7 @@ inline void MapObj::erase(const MapObj::iterator& position) {
12491287
inline ObjectPtr<MapObj> MapObj::Empty() { return SmallMapObj::Empty(); }
12501288

12511289
inline ObjectPtr<MapObj> MapObj::CopyFrom(MapObj* from) {
1252-
if (from->slots_ <= SmallMapObj::kMaxSize) {
1290+
if (from->IsSmallMap()) {
12531291
return SmallMapObj::CopyFrom(static_cast<SmallMapObj*>(from));
12541292
} else {
12551293
return DenseMapObj::CopyFrom(static_cast<DenseMapObj*>(from));
@@ -1288,20 +1326,22 @@ inline ObjectPtr<Object> MapObj::CreateFromRange(IterType first, IterType last)
12881326
}
12891327

12901328
inline void MapObj::InsertMaybeReHash(KVType&& kv, ObjectPtr<Object>* map) {
1291-
constexpr uint64_t kSmallMapMaxSize = SmallMapObj::kMaxSize;
12921329
MapObj* base = static_cast<MapObj*>(map->get());
12931330
#if TVM_FFI_DEBUG_WITH_ABI_CHANGE
12941331
base->state_marker++;
12951332
#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE
1296-
if (base->slots_ < kSmallMapMaxSize) {
1297-
SmallMapObj::InsertMaybeReHash(std::move(kv), map);
1298-
} else if (base->slots_ == kSmallMapMaxSize) {
1299-
if (base->size_ < base->slots_) {
1333+
if (base->IsSmallMap()) {
1334+
SmallMapObj* sm = static_cast<SmallMapObj*>(base);
1335+
if (sm->NumSlots() < SmallMapObj::kMaxSize) {
13001336
SmallMapObj::InsertMaybeReHash(std::move(kv), map);
1301-
} else {
1302-
ObjectPtr<Object> new_map = MapObj::CreateFromRange(base->begin(), base->end());
1303-
DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map);
1304-
*map = std::move(new_map);
1337+
} else if (sm->NumSlots() == SmallMapObj::kMaxSize) {
1338+
if (base->size_ < sm->NumSlots()) {
1339+
SmallMapObj::InsertMaybeReHash(std::move(kv), map);
1340+
} else {
1341+
ObjectPtr<Object> new_map = MapObj::CreateFromRange(base->begin(), base->end());
1342+
DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map);
1343+
*map = std::move(new_map);
1344+
}
13051345
}
13061346
} else {
13071347
DenseMapObj::InsertMaybeReHash(std::move(kv), map);

0 commit comments

Comments
 (0)