Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions include/tvm/runtime/container/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
namespace tvm {
namespace runtime {

#if TVM_LOG_DEBUG
#define TVM_MAP_FAIL_IF_CHANGED() \
ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map";
#else
#define TVM_MAP_FAIL_IF_CHANGED()
#endif // TVM_LOG_DEBUG

#if (USE_FALLBACK_STL_MAP != 0)

/*! \brief Shared content of all specializations of hash map */
Expand Down Expand Up @@ -233,38 +240,55 @@ class MapNode : public Object {
using value_type = KVType;
using pointer = KVType*;
using reference = KVType&;
/*! \brief Default constructor */
/*! \brief Default constructor */
#if TVM_LOG_DEBUG
iterator() : state_marker(0), index(0), self(nullptr) {}
#else
iterator() : index(0), self(nullptr) {}
#endif // TVM_LOG_DEBUG
/*! \brief Compare iterators */
bool operator==(const iterator& other) const {
TVM_MAP_FAIL_IF_CHANGED()
return index == other.index && self == other.self;
}
/*! \brief Compare iterators */
bool operator!=(const iterator& other) const { return !(*this == other); }
/*! \brief De-reference iterators */
pointer operator->() const;
/*! \brief De-reference iterators */
reference operator*() const { return *((*this).operator->()); }
reference operator*() const {
TVM_MAP_FAIL_IF_CHANGED()
return *((*this).operator->());
}
/*! \brief Prefix self increment, e.g. ++iter */
iterator& operator++();
/*! \brief Prefix self decrement, e.g. --iter */
iterator& operator--();
/*! \brief Suffix self increment */
iterator operator++(int) {
TVM_MAP_FAIL_IF_CHANGED()
iterator copy = *this;
++(*this);
return copy;
}
/*! \brief Suffix self decrement */
iterator operator--(int) {
TVM_MAP_FAIL_IF_CHANGED()
iterator copy = *this;
--(*this);
return copy;
}

protected:
#if TVM_LOG_DEBUG
uint64_t state_marker;
/*! \brief Construct by value */
iterator(uint64_t index, const MapNode* self)
: state_marker(self->state_marker), index(index), self(self) {}

#else
iterator(uint64_t index, const MapNode* self) : index(index), self(self) {}
#endif // TVM_LOG_DEBUG
/*! \brief The position on the array */
uint64_t index;
/*! \brief The container it points to */
Expand All @@ -280,6 +304,9 @@ class MapNode : public Object {
static inline ObjectPtr<MapNode> Empty();

protected:
#if TVM_LOG_DEBUG
uint64_t state_marker;
#endif // TVM_LOG_DEBUG
/*!
* \brief Create the map using contents from the given iterators.
* \param first Begin of iterator
Expand Down Expand Up @@ -1118,17 +1145,20 @@ class DenseMapNode : public MapNode {
}

inline MapNode::iterator::pointer MapNode::iterator::operator->() const {
TVM_MAP_FAIL_IF_CHANGED()
TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); });
}

inline MapNode::iterator& MapNode::iterator::operator++() {
TVM_MAP_FAIL_IF_CHANGED()
TVM_DISPATCH_MAP_CONST(self, p, {
index = p->IncItr(index);
return *this;
});
}

inline MapNode::iterator& MapNode::iterator::operator--() {
TVM_MAP_FAIL_IF_CHANGED()
TVM_DISPATCH_MAP_CONST(self, p, {
index = p->DecItr(index);
return *this;
Expand Down Expand Up @@ -1200,6 +1230,9 @@ inline ObjectPtr<Object> MapNode::CreateFromRange(IterType first, IterType last)
inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize;
MapNode* base = static_cast<MapNode*>(map->get());
#if TVM_LOG_DEBUG
base->state_marker++;
#endif // TVM_LOG_DEBUG
if (base->slots_ < kSmallMapMaxSize) {
SmallMapNode::InsertMaybeReHash(kv, map);
} else if (base->slots_ == kSmallMapMaxSize) {
Expand Down
15 changes: 15 additions & 0 deletions tests/cpp/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,21 @@ TEST(Map, Erase) {
}
}

#if TVM_LOG_DEBUG
TEST(Map, Race) {
using namespace tvm::runtime;
Map<Integer, Integer> m;

m.Set(1, 1);
Map<tvm::Integer, tvm::Integer>::iterator it = m.begin();
EXPECT_NO_THROW({ auto& kv = *it; });

m.Set(2, 2);
// changed. iterator should be re-obtained
EXPECT_ANY_THROW({ auto& kv = *it; });
}
#endif // TVM_LOG_DEBUG

TEST(String, MoveFromStd) {
using namespace std;
string source = "this is a string";
Expand Down