diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h index 977dbfbaaaa1..4c76a3b0ad4f 100644 --- a/include/tvm/runtime/container/map.h +++ b/include/tvm/runtime/container/map.h @@ -38,6 +38,13 @@ namespace tvm { namespace runtime { +#if TVM_LOG_DEBUG +#define TVM_MAP_FAIL_IF_CHANGED() \ + ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; +#else +#define TVM_MAP_FAIL_IF_CHANGED() +#endif // TVM_LOG_DEBUG + #if (USE_FALLBACK_STL_MAP != 0) /*! \brief Shared content of all specializations of hash map */ @@ -233,10 +240,15 @@ class MapNode : public Object { using value_type = KVType; using pointer = KVType*; using reference = KVType&; - /*! \brief Default constructor */ +/*! \brief Default constructor */ +#if TVM_LOG_DEBUG + iterator() : state_marker(0), index(0), self(nullptr) {} +#else iterator() : index(0), self(nullptr) {} +#endif // TVM_LOG_DEBUG /*! \brief Compare iterators */ bool operator==(const iterator& other) const { + TVM_MAP_FAIL_IF_CHANGED() return index == other.index && self == other.self; } /*! \brief Compare iterators */ @@ -244,27 +256,39 @@ class MapNode : public Object { /*! \brief De-reference iterators */ pointer operator->() const; /*! \brief De-reference iterators */ - reference operator*() const { return *((*this).operator->()); } + reference operator*() const { + TVM_MAP_FAIL_IF_CHANGED() + return *((*this).operator->()); + } /*! \brief Prefix self increment, e.g. ++iter */ iterator& operator++(); /*! \brief Prefix self decrement, e.g. --iter */ iterator& operator--(); /*! \brief Suffix self increment */ iterator operator++(int) { + TVM_MAP_FAIL_IF_CHANGED() iterator copy = *this; ++(*this); return copy; } /*! \brief Suffix self decrement */ iterator operator--(int) { + TVM_MAP_FAIL_IF_CHANGED() iterator copy = *this; --(*this); return copy; } protected: +#if TVM_LOG_DEBUG + uint64_t state_marker; /*! \brief Construct by value */ + iterator(uint64_t index, const MapNode* self) + : state_marker(self->state_marker), index(index), self(self) {} + +#else iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} +#endif // TVM_LOG_DEBUG /*! \brief The position on the array */ uint64_t index; /*! \brief The container it points to */ @@ -280,6 +304,9 @@ class MapNode : public Object { static inline ObjectPtr Empty(); protected: +#if TVM_LOG_DEBUG + uint64_t state_marker; +#endif // TVM_LOG_DEBUG /*! * \brief Create the map using contents from the given iterators. * \param first Begin of iterator @@ -1118,10 +1145,12 @@ class DenseMapNode : public MapNode { } inline MapNode::iterator::pointer MapNode::iterator::operator->() const { + TVM_MAP_FAIL_IF_CHANGED() TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); } inline MapNode::iterator& MapNode::iterator::operator++() { + TVM_MAP_FAIL_IF_CHANGED() TVM_DISPATCH_MAP_CONST(self, p, { index = p->IncItr(index); return *this; @@ -1129,6 +1158,7 @@ inline MapNode::iterator& MapNode::iterator::operator++() { } inline MapNode::iterator& MapNode::iterator::operator--() { + TVM_MAP_FAIL_IF_CHANGED() TVM_DISPATCH_MAP_CONST(self, p, { index = p->DecItr(index); return *this; @@ -1200,6 +1230,9 @@ inline ObjectPtr MapNode::CreateFromRange(IterType first, IterType last) inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize; MapNode* base = static_cast(map->get()); +#if TVM_LOG_DEBUG + base->state_marker++; +#endif // TVM_LOG_DEBUG if (base->slots_ < kSmallMapMaxSize) { SmallMapNode::InsertMaybeReHash(kv, map); } else if (base->slots_ == kSmallMapMaxSize) { diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 019fde069878..32ec346c8796 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -380,6 +380,21 @@ TEST(Map, Erase) { } } +#if TVM_LOG_DEBUG +TEST(Map, Race) { + using namespace tvm::runtime; + Map m; + + m.Set(1, 1); + Map::iterator it = m.begin(); + EXPECT_NO_THROW({ auto& kv = *it; }); + + m.Set(2, 2); + // changed. iterator should be re-obtained + EXPECT_ANY_THROW({ auto& kv = *it; }); +} +#endif // TVM_LOG_DEBUG + TEST(String, MoveFromStd) { using namespace std; string source = "this is a string";